From 130b97ea0c1d9b9a4211434dd39857c79b4ed292 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 29 Mar 2022 13:46:14 +0100 Subject: [PATCH 01/60] partial fix --- Cargo.lock | 154 +++--- shotover-proxy/Cargo.toml | 3 +- shotover-proxy/src/frame/cassandra.rs | 193 +++++--- shotover-proxy/src/message/mod.rs | 81 +++- shotover-proxy/src/transforms/protect/mod.rs | 267 ++++++---- shotover-proxy/src/transforms/redis/cache.rs | 483 ++++++++----------- 6 files changed, 697 insertions(+), 484 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d23e2f8b3..376123415 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,9 +83,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.52" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061a7acccaa286c011ddc30970520b98fa40e00c9d644633fb26b5fc63a265e3" +checksum = "ed6aa3524a2dfcf9fe180c51eae2b58738348d819517ceadf95789c51fff7600" dependencies = [ "proc-macro2", "quote", @@ -225,9 +225,9 @@ checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" [[package]] name = "bytes-utils" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e314712951c43123e5920a446464929adc667a5eade7f8fb3997776c9df6e54" +checksum = "1934a3ef9cac8efde4966a92781e77713e1ba329f1d42e446c7d7eba340d8ef1" dependencies = [ "bytes", "either", @@ -304,7 +304,7 @@ dependencies = [ "num", "snap", "thiserror", - "time 0.3.7", + "time 0.3.9", "uuid", ] @@ -428,13 +428,24 @@ checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" [[package]] name = "cpufeatures" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" +checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" dependencies = [ "libc", ] +[[package]] +name = "cql3_parser" +version = "0.0.1" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#464bf95b8dfbe5ee7d4ab5d91af07871aca17507" +dependencies = [ + "itertools", + "regex", + "tree-sitter", + "tree-sitter-cql", +] + [[package]] name = "crc16" version = "0.4.0" @@ -488,9 +499,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.2" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e54ea8bc3fb1ee042f5aace6e3c6e025d3874866da222930f70ce62aceba0bfa" +checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53" dependencies = [ "cfg-if", "crossbeam-utils", @@ -509,10 +520,11 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c00d6d2ea26e8b151d99093005cb442fb9a37aeaca582a03ec70946f49ab5ed9" +checksum = "1145cf131a2c6ba0615079ab6a638f7e1973ac9c2634fcbeaaad6114246efe8c" dependencies = [ + "autocfg", "cfg-if", "crossbeam-utils", "lazy_static", @@ -522,9 +534,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e5bed1f1c269533fa816a0a5492b3545209a205ca1a54842be180eb63a16a6" +checksum = "0bf124c720b7686e3c2663cf54062ab0f68a88af2fb6a030e87e30bf721fcb38" dependencies = [ "cfg-if", "lazy_static", @@ -659,9 +671,9 @@ checksum = "56899898ce76aaf4a0f24d914c97ea6ed976d42fec6ad33fcbb0a1103e07b2b0" [[package]] name = "ed25519" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed12bbf7b5312f8da1c2722bc06d8c6b12c2d86a7fb35a194c7f3e6fc2bbe39" +checksum = "3d5c4b5e5959dc2c2b89918d8e2cc40fcdd623cef026ed09d2f0ee05199dc8e4" dependencies = [ "signature", ] @@ -893,9 +905,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" +checksum = "9be70c98951c83b8d2f8f60d7065fa6d5146873094452a1008da8c2f1e4205ad" dependencies = [ "cfg-if", "libc", @@ -951,9 +963,9 @@ checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" [[package]] name = "halfbrown" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49e26621a30b9fdb4f949b9c6a7fa42ce88112851c33ac4ca00bfa7848d26fb4" +checksum = "102968a036b06654b555049d9a6c4f46046805d1e1b22647720e93e0704d4c60" dependencies = [ "hashbrown 0.12.0", ] @@ -1058,9 +1070,9 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.17" +version = "0.14.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "043f0e083e9901b6cc658a77d1eb86f4fc650bbb977a4337dd63192826aa85dd" +checksum = "b26ae0a80afebe130861d90abf98e3814a4f28a4c6ffeb5ab8ebb2be311e0ef2" dependencies = [ "bytes", "futures-channel", @@ -1186,9 +1198,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad5c14e80759d0939d013e6ca49930e59fc53dd8e5009132f76240c179380c09" +checksum = "efaa7b300f3b5fe8eb6bf21ce3895e1751d9665086af2d64b42f19701015ff4f" [[package]] name = "libloading" @@ -1235,9 +1247,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.14" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" +checksum = "6389c490849ff5bc16be905ae24bc913a9c8892e19b2341dbc175e14c341c2b8" dependencies = [ "cfg-if", ] @@ -1381,9 +1393,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ba42135c6a5917b9db9cd7b293e5409e1c6b041e6f9825e92e55a894c63b6f8" +checksum = "52da4364ffb0e4fe33a9841a98a3f3014fb964045ce4f7a45a398243c8d6b0c9" dependencies = [ "libc", "log", @@ -1404,9 +1416,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.8" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" +checksum = "fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9" dependencies = [ "lazy_static", "libc", @@ -1600,9 +1612,9 @@ dependencies = [ [[package]] name = "num_threads" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c539a50b93a303167eded6e8dff5220cd39447409fb659f4cd24b1f72fe4f133" +checksum = "aba1801fb138d8e85e11d0fc70baf4fe1cdfffda7c6cd34a854905df588e5ed0" dependencies = [ "libc", ] @@ -1656,9 +1668,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.17.0+1.1.1m" +version = "111.18.0+1.1.1n" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05d6a336abd10814198f66e2a91ccd7336611f30334119ca8ce300536666fcf4" +checksum = "7897a926e1e8d00219127dc020130eca4292e5ca666dd592480d72c3eca2ff6c" dependencies = [ "cc", ] @@ -1927,9 +1939,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "864d3e96a899863136fc6e99f3d7cae289dafe43bf2c5ac19b70df7210c0a145" +checksum = "632d02bff7f874a36f33ea8bb416cd484b90cc66c1194b1a1110d067a7013f58" dependencies = [ "proc-macro2", ] @@ -1982,9 +1994,9 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "10.2.0" +version = "10.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "929f54e29691d4e6a9cc558479de70db7aa3d98cd6fe7ab86d7507aa2886b9d2" +checksum = "738bc47119e3eeccc7e94c4a506901aea5e7b4944ecd0829cbebf4af04ceda12" dependencies = [ "bitflags", ] @@ -2052,21 +2064,22 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8380fe0152551244f0747b1bf41737e0f8a74f97a14ccefd1148187271634f3c" +checksum = "8ae183fc1b06c149f0c1793e1eb447c8b04bfe46d48e9e48bfb8d2d7ed64ecf0" dependencies = [ "bitflags", ] [[package]] name = "redox_users" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64" +checksum = "7776223e2696f1aa4c6b0170e83212f47296a00424305117d013dfe86fb0fe55" dependencies = [ "getrandom", "redox_syscall", + "thiserror", ] [[package]] @@ -2309,9 +2322,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a3381e03edd24287172047536f20cabde766e2cd3e65e6b00fb3af51c4f38d" +checksum = "d65bd28f48be7196d222d95b9243287f48d27aca604e08497513019ff0502cc4" [[package]] name = "serde" @@ -2461,6 +2474,7 @@ dependencies = [ "cassandra-cpp", "cassandra-protocol", "clap 3.1.6", + "cql3_parser", "crc16", "criterion", "csv", @@ -2506,7 +2520,7 @@ dependencies = [ "tokio-io-timeout", "tokio-openssl", "tokio-stream", - "tokio-util 0.7.0", + "tokio-util 0.7.1", "tracing", "tracing-appender", "tracing-log", @@ -2589,9 +2603,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.15.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adbbea2526ad0d02ad9414a07c396078a5b944bbf9ca4fbab8f01bb4cb579081" +checksum = "b8f192f29f4aa49e57bebd0aa05858e0a1f32dd270af36efe49edb82cbfffab6" dependencies = [ "log", ] @@ -2639,9 +2653,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.88" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebd69e719f31e88618baa1eaa6ee2de5c9a1c004f1e9ecdb58e8352a13f20a01" +checksum = "704df27628939572cd88d33f171cd6f896f4eaca85252c6e0a72d8d8287ee86f" dependencies = [ "proc-macro2", "quote", @@ -2753,9 +2767,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.7" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "004cbc98f30fa233c61a38bc77e96a9106e65c88f2d3bef182ae952027e5753d" +checksum = "c2702e08a7a860f005826c6815dcac101b19b5eb330c27fe4a5928fec1d20ddd" dependencies = [ "itoa 1.0.1", "libc", @@ -2765,9 +2779,9 @@ dependencies = [ [[package]] name = "time-macros" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25eb0ca3468fc0acc11828786797f6ef9aa1555e4a211a60d64cc8e4d1be47d6" +checksum = "42657b1a6f4d817cda8e7a0ace261fe0cc946cf3a80314390b22cc61ae080792" [[package]] name = "tinytemplate" @@ -2898,19 +2912,19 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64910e1b9c1901aaf5375561e35b9c057d95ff41a44ede043a03e09279eabaf1" +checksum = "0edfdeb067411dba2044da6d1cb2df793dd35add7888d73c16e3381ded401764" dependencies = [ "bytes", "futures-core", "futures-io", "futures-sink", "futures-util", - "log", "pin-project-lite", "slab", "tokio", + "tracing", ] [[package]] @@ -2933,12 +2947,12 @@ dependencies = [ [[package]] name = "tracing-appender" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ab026b18a46ac429e5c98bec10ca06424a97b3ad7b3949d9b4a102fff6623c4" +checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" dependencies = [ "crossbeam-channel", - "time 0.3.7", + "time 0.3.9", "tracing-subscriber", ] @@ -2993,6 +3007,26 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tree-sitter" +version = "0.20.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b3b781640108d29892e8b9684642d2cda5ea05951fd58f0fea1db9edeb9b71" +dependencies = [ + "cc", + "regex", +] + +[[package]] +name = "tree-sitter-cql" +version = "0.0.1" +source = "git+https://github.com/Claude-at-Instaclustr/tree-sitter-cql?branch=main#495dffa341b8342312abd895a2bc4b3d316db23a" +dependencies = [ + "cc", + "regex", + "tree-sitter", +] + [[package]] name = "try-lock" version = "0.2.3" @@ -3302,6 +3336,6 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.5.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50344758e2f40e3a1fcfc8f6f91aa57b5f8ebd8d27919fe6451f15aaaf9ee608" +checksum = "7eb5728b8afd3f280a869ce1d4c554ffaed35f45c231fc41bfbd0381bef50317" diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 6964a27d6..c6389f6c9 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -41,7 +41,8 @@ thiserror = "1.0" anyhow = "1.0.31" # Parsers -sqlparser = "0.15" +sqlparser = "0.14" +cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch = "main" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index d491df85f..1dcc89123 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,4 +1,6 @@ -use anyhow::{anyhow, Result}; +use std::net::IpAddr; +use std::str::FromStr; +use anyhow::anyhow; use bytes::Bytes; use cassandra_protocol::compression::Compression; use cassandra_protocol::consistency::Consistency; @@ -16,15 +18,18 @@ use cassandra_protocol::frame::{ Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version, }; use cassandra_protocol::query::{QueryParams, QueryValues}; -use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; +use cassandra_protocol::types::{AsCassandraType, CBytes, CBytesShort, CInt, CLong}; use itertools::Itertools; use nonzero_ext::nonzero; -use sqlparser::ast::{SetExpr, Statement, TableFactor}; -use sqlparser::dialect::GenericDialect; -use sqlparser::parser::Parser; use std::convert::TryInto; use std::num::NonZeroU32; use std::slice::IterMut; +use cassandra_protocol::types::blob::Blob; +use cassandra_protocol::types::cassandra_type::CassandraType; +use cql3_parser::cassandra_ast::CassandraAST; +use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::common::Operand; +use sodiumoxide::hex; use uuid::Uuid; use crate::message::{MessageValue, QueryType}; @@ -75,6 +80,10 @@ pub(crate) fn cell_count(bytes: &[u8]) -> Result { _ => nonzero!(1u32), }) } +======= +// TODO remove this and use actual default from session. +const DEFAULT_KEYSPACE : &str = ""; +>>>>>>> partial fix #[derive(PartialEq, Debug, Clone)] pub struct CassandraFrame { @@ -235,17 +244,56 @@ impl CassandraFrame { } pub fn get_query_type(&self) -> QueryType { + /* + Read, + Write, + ReadWrite, + SchemaChange, + PubSubMessage, + */ match &self.operation { CassandraOperation::Query { - query: CQL::Parsed(query), + query: cql, .. - } => match query.get(0) { - Some(Statement::Query(_x)) => QueryType::Read, - Some(Statement::Insert { .. }) => QueryType::Write, - Some(Statement::Update { .. }) => QueryType::Write, - Some(Statement::Delete { .. }) => QueryType::Write, - // TODO: handle prepared, execute and schema change query types - _ => QueryType::Read, + } => match cql.statement.get(0).unwrap() { + CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, + CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, + CassandraStatement::AlterRole(_) => QueryType::SchemaChange, + CassandraStatement::AlterTable(_) => QueryType::SchemaChange, + CassandraStatement::AlterType(_) => QueryType::SchemaChange, + CassandraStatement::AlterUser(_) => QueryType::SchemaChange, + CassandraStatement::ApplyBatch => QueryType::ReadWrite, + CassandraStatement::CreateAggregate(_) => QueryType::SchemaChange, + CassandraStatement::CreateFunction(_) => QueryType::SchemaChange, + CassandraStatement::CreateIndex(_) => QueryType::SchemaChange, + CassandraStatement::CreateKeyspace(_) => QueryType::SchemaChange, + CassandraStatement::CreateMaterializedView(_) => QueryType::SchemaChange, + CassandraStatement::CreateRole(_) => QueryType::SchemaChange, + CassandraStatement::CreateTable(_) => QueryType::SchemaChange, + CassandraStatement::CreateTrigger(_) => QueryType::SchemaChange, + CassandraStatement::CreateType(_) => QueryType::SchemaChange, + CassandraStatement::CreateUser(_) => QueryType::SchemaChange, + CassandraStatement::DeleteStatement(_) => QueryType::Write, + CassandraStatement::DropAggregate(_) => QueryType::SchemaChange, + CassandraStatement::DropFunction(_) => QueryType::SchemaChange, + CassandraStatement::DropIndex(_) => QueryType::SchemaChange, + CassandraStatement::DropKeyspace(_) => QueryType::SchemaChange, + CassandraStatement::DropMaterializedView(_) => QueryType::SchemaChange, + CassandraStatement::DropRole(_) => QueryType::SchemaChange, + CassandraStatement::DropTable(_) => QueryType::SchemaChange, + CassandraStatement::DropTrigger(_) => QueryType::SchemaChange, + CassandraStatement::DropType(_) => QueryType::SchemaChange, + CassandraStatement::DropUser(_) => QueryType::SchemaChange, + CassandraStatement::Grant(_) => QueryType::SchemaChange, + CassandraStatement::Insert(_) => QueryType::Write, + CassandraStatement::ListPermissions(_) => QueryType::Read, + CassandraStatement::ListRoles(_) => QueryType::Read, + CassandraStatement::Revoke(_) => QueryType::SchemaChange, + CassandraStatement::Select(_) => QueryType::Read, + CassandraStatement::Truncate( _) => QueryType::Write, + CassandraStatement::Update(_) => QueryType::Write, + CassandraStatement::Use(_) => QueryType::SchemaChange, + CassandraStatement::Unknown(_) => QueryType::Read, }, _ => QueryType::Read, } @@ -254,33 +302,9 @@ impl CassandraFrame { pub fn namespace(&self) -> Vec { match &self.operation { CassandraOperation::Query { - query: CQL::Parsed(query), + query: cql, .. - } => match query.first() { - Some(Statement::Query(query)) => match &query.body { - SetExpr::Select(select) => { - if let TableFactor::Table { name, .. } = - &select.from.get(0).unwrap().relation - { - name.0.iter().map(|a| a.value.clone()).collect() - } else { - vec![] - } - } - _ => vec![], - }, - Some(Statement::Insert { table_name, .. }) - | Some(Statement::Delete { table_name, .. }) => { - table_name.0.iter().map(|a| a.value.clone()).collect() - } - Some(Statement::Update { table, .. }) => match &table.relation { - TableFactor::Table { name, .. } => { - name.0.iter().map(|a| a.value.clone()).collect() - } - _ => vec![], - }, - _ => vec![], - }, + } => cql.statement.iter().map( |x|x.get_keyspace( DEFAULT_KEYSPACE )).collect(), _ => vec![], } } @@ -463,34 +487,95 @@ impl CassandraOperation { } #[derive(PartialEq, Debug, Clone)] -pub enum CQL { - Parsed(Vec), - FailedToParse(String), +pub struct CQL { + pub statement : Vec, + pub has_error : Vec } impl CQL { pub fn to_query_string(&self) -> String { - match self { - CQL::Parsed(ast) => ast.iter().map(|x| x.to_string()).join(""), - CQL::FailedToParse(str) => str.clone(), + self.statement.get(0).unwrap().to_string() + } + + pub fn parse_from_string(cql: String) -> Self { + let ast = CassandraAST::new(cql ); + CQL { + statement : vec![ast.statement], + has_error : vec![ast.has_error()], } } +} - pub fn parse_from_string(sql: String) -> Self { - match Parser::parse_sql(&GenericDialect, &sql) { - _ if sql.contains("ALTER TABLE") || sql.contains("CREATE TABLE") => { - tracing::error!("Failed to parse CQL for frame {:?}\nError: Blacklisted query as sqlparser crate cant round trip it", sql); - CQL::FailedToParse(sql) +pub trait ToCassandraType { + fn from_string_value(&self, value : &str) -> Option; + fn as_cassandra_type(&self) -> Option; +} + +impl ToCassandraType for Operand { + fn from_string_value(&self, value : &str ) -> Option { + // check for string types + if value.starts_with("'") || value.starts_with("$$") { + Some(CassandraType::Varchar(value.to_string())) + } else if value.starts_with("0X") || value.starts_with("X'") { + let mut chars = value.chars(); + chars.next(); + chars.next(); + let bytes = hex::decode(chars.as_str()).unwrap(); + Some(CassandraType::Blob(Blob::from(bytes))) + } else { + let num = i64::from_str(value); + if num.is_ok() { + Some(CassandraType::Bigint(num.unwrap())) + } else { + let num = f64::from_str(value); + if num.is_ok() { + Some(CassandraType::Double(num.unwrap())) + } else { + let uuid = Uuid::parse_str(value); + if uuid.is_ok() { + Some(CassandraType::Uuid(uuid.unwrap())) + } else { + let ipaddr = IpAddr::from_str(value); + if ipaddr.is_ok() { + Some(CassandraType::Inet(ipaddr.unwrap())) + } + None + } + } + } + } + } + + + fn as_cassandra_type(&self) -> Option { + match self { + Operand::Const(value) => { + self.from_string_value( value ) + } + Operand::Map(values) => { + CassandraType::Map(values.iter().map( |(key,value)| (self.from_string_value( key), self.from_string_value(value)) ).collect()) + } + Operand::Set( values) => { + CassandraType::Set(values.iter().map( |value| self.from_string_value(value) ).collect()) + } + Operand::List( values) => { + CassandraType::List(values.iter().map( |value| self.from_string_value(value) ).collect()) } - Ok(ast) => CQL::Parsed(ast), - Err(err) => { - tracing::error!("Failed to parse CQL for frame {:?}\nError: {:?}", sql, err); - CQL::FailedToParse(sql) + Operand::Tuple( values) => { + CassandraType::Tuple( values.iter().map( |o| o.as_cassandra_type()).collect()) } + Operand::Column(value) => { + CassandraType::Ascii( value.to_string() ) + } + Operand::Func(value) => { + CassandraType::Ascii( value.to_string() ) + } + Operand::Null => CassandraType::Null, } } } + #[derive(PartialEq, Debug, Clone)] pub enum CassandraResult { Rows { diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 13f24d546..83a768c66 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -1,8 +1,6 @@ use crate::codec::redis::redis_query_type; -use crate::frame::{ cassandra, - cassandra::{CassandraMetadata, CassandraOperation}, -}; + cassandra::{CassandraMetadata, CassandraOperation, ToCassandraType}; use crate::frame::{CassandraFrame, Frame, MessageType, RedisFrame}; use anyhow::{anyhow, Result}; use bigdecimal::BigDecimal; @@ -24,11 +22,15 @@ use nonzero_ext::nonzero; use num::BigInt; use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; -use sqlparser::ast::Value as SQLValue; use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; use std::num::NonZeroU32; use uuid::Uuid; +use std::str::FromStr; +use cql3_parser::common::{DataTypeName, Operand}; +use cql3_parser::common::DataTypeName::Ascii; +use sodiumoxide::hex; + enum Metadata { Cassandra(CassandraMetadata), @@ -36,6 +38,7 @@ enum Metadata { None, } + pub type Messages = Vec; /// The Message type is designed to effeciently abstract over the message being in various states of processing. @@ -473,7 +476,7 @@ impl From<&SQLValue> for MessageValue { match v { SQLValue::Number(v, false) | SQLValue::SingleQuotedString(v) - | SQLValue::NationalStringLiteral(v) => MessageValue::Strings(v.clone()), + | SQLValue::NationalStringLiteral(v) => MessageValue::Strings(v.to_string()), SQLValue::HexStringLiteral(v) => MessageValue::Strings(v.to_string()), SQLValue::Boolean(v) => MessageValue::Boolean(*v), _ => MessageValue::Strings("NULL".to_string()), @@ -481,6 +484,72 @@ impl From<&SQLValue> for MessageValue { } } +impl From<&MessageValue> for Operand { + fn from(v: &MessageValue) -> Self { + match v { + MessageValue::NULL => Operand::Null, + MessageValue::Bytes(b) => Operand::Const( format!("0X{}", b.encode_hex() )), + MessageValue::Strings(s) => Operand::Const(format!("'{}'", s)), + MessageValue::Integer(i, _) => Operand::Const(i.to_string()), + MessageValue::Float(f) => Operand::Const(f.to_string()), + MessageValue::Boolean(b) => Operand::Const( if b { "TRUE".to_string() } else {"FALSE".to_string()}), + + _ => {} + } + } +} + +impl From<&Operand> for MessageValue { + + fn from(operand: &Operand) -> Self { + operand.as_cassandra_type().map_or( MessageValue::None, |x| MessageValue::create_element( x )) + } +} + +impl From<&MessageValue> for DataTypeName { + fn from(v: &MessageValue) -> Self { + match v { + MessageValue::Bytes(_) => DataTypeName::Blob, + MessageValue::Ascii(_) => DataTypeName::Ascii, + MessageValue::Strings(_) => DataTypeName::Text, + MessageValue::Integer(_, size) => { + match size { + //DataTypeName::Int + IntSize::I64 => DataTypeName::BigInt, + IntSize::I32 => DataTypeName::Int, + IntSize::I16 => DataTypeName::SmallInt, + IntSize::I8 => DataTypeName::TinyInt, + } + }, + MessageValue::Double(_) => DataTypeName::Double, + MessageValue::Float(_) => DataTypeName::Float, + MessageValue::Boolean(_) => DataTypeName::Boolean, + MessageValue::Inet(_) => DataTypeName::Inet, + MessageValue::List(_) => DataTypeName::List, + MessageValue::Rows(_) => DataTypeName::List, + MessageValue::NamedRows(_) => DataTypeName::Tuple, + MessageValue::Document(_) => DataTypeName::Tuple, + MessageValue::FragmentedResponse(_) => DataTypeName::Tuple, + MessageValue::Set(_) => DataTypeName::Set, + MessageValue::Map(_) => DataTypeName::Map, + MessageValue::Varint(_) => DataTypeName::VarInt, + MessageValue::Decimal(_) => DataTypeName::Decimal, + MessageValue::Date(_) => DataTypeName::Date, + MessageValue::Timestamp(_) => DataTypeName::Timestamp, + MessageValue::Timeuuid(_) => DataTypeName::TimeUuid, + MessageValue::Varchar(_) => DataTypeName::VarChar, + MessageValue::Uuid(_) => DataTypeName::Uuid, + MessageValue::Time(_) => DataTypeName::Time, + MessageValue::Counter(_) => DataTypeName::Counter, + MessageValue::Tuple(_) => DataTypeName::Tuple, + MessageValue::Udt(_) => DataTypeName::Tuple, + MessageValue::NULL => {}, + None => {}, + } + } + +} + impl From for MessageValue { fn from(f: RedisFrame) -> Self { match f { @@ -574,7 +643,7 @@ impl MessageValue { wrapper(data, col_type).unwrap() } - fn create_element(element: CassandraType) -> MessageValue { + pub fn create_element(element: CassandraType) -> MessageValue { match element { CassandraType::Ascii(a) => MessageValue::Ascii(a), CassandraType::Bigint(b) => MessageValue::Integer(b, IntSize::I64), diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 38a907fd0..7bdd1c44d 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,5 +1,5 @@ use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, CQL, Frame}; use crate::message::MessageValue; use crate::transforms::protect::key_management::{KeyManager, KeyManagerConfig}; use crate::transforms::{Transform, Transforms, Wrapper}; @@ -11,11 +11,19 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::secretbox; use sodiumoxide::crypto::secretbox::{Key, Nonce}; -use sqlparser::ast::{Assignment, Expr, Ident, Query, SetExpr, Statement, Value as SQLValue}; +//use sqlparser::ast::{Assignment, Expr, Ident, Query, SetExpr, Statement, Value as SQLValue}; use std::borrow::BorrowMut; use std::collections::HashMap; +use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::common::Operand; +use cql3_parser::insert::InsertValues; +use cql3_parser::select::SelectElement; +use cql3_parser::update::AssignmentElement; +use serde_yaml::seed::from_slice_seed; +use sodiumoxide::hex; use tracing::warn; + mod aws_kms; mod key_management; mod local_kek; @@ -90,6 +98,20 @@ impl From for MessageValue { } } +impl From for Operand { + fn from(p: Protected) -> Self { + match p { + Protected::Plaintext(_) => panic!( + "tried to move unencrypted value to plaintext without explicitly calling decrypt" + ), + Protected::Ciphertext { .. } => { + Operand::Const( format!( "0X{}", + hex::encode(serde_json::to_vec(&p).unwrap()))) + } + } + } +} + impl Protected { pub async fn from_encrypted_bytes_value(value: &MessageValue) -> Result { match value { @@ -156,44 +178,91 @@ impl ProtectConfig { } } -pub fn get_values_from_insert_or_update_mut(ast: &mut Statement) -> HashMap { - match ast { - Statement::Insert { - source, columns, .. - } => get_values_from_insert_mut(columns.as_mut(), source.borrow_mut()), - Statement::Update { assignments, .. } => get_values_from_update_mut(assignments.as_mut()), - _ => HashMap::new(), - } -} +pub fn get_values_from_insert_or_update_mut(ast: &mut CQL) -> HashMap { + match ast.statement[0] { + Some(stmt) => + match stmt { + CassandraStatement::Insert(insert) => { + match insert.values { + InsertValues::Values(values) => { + let mut result = HashMap::new(); + // if the lengths don't match we will return an empty hashmap. + if values.len() == insert.columns.len() { + for (i, value) in values.iter().enumerate() { + if let Operand::Const(val) = value { + result.insert( insert.columns[i].to_string(), value); + } + } + } else { + // TODO do we need to clear data here? + } + result + } + // TODO parse JSON? + InsertValues::Json(_) => HashMap::new() + } + } -fn get_values_from_insert_mut<'a>( - columns: &'a mut [Ident], - source: &'a mut Query, -) -> HashMap { - let mut map = HashMap::new(); - let mut columns_iter = columns.iter(); - if let SetExpr::Values(v) = &mut source.body { - for value in &mut v.0 { - for ex in value { - if let Expr::Value(v) = ex { - if let Some(c) = columns_iter.next() { - map.insert(c.value.to_string(), v); + CassandraStatement::Update(update) => { + let mut result = HashMap::new(); + for assignment in update.assignments { + // the operator adds something like +x or -x to the assignment so it indicates this is not a value + // and thus we should skip it. + if assignment.operator.is_none() { + if let Operand::Const(val) = &assignment.value { + result.insert( assignment.name.to_string(), assignment.value); + } + } else { + // TODO do we need to clear data here? + } } + result } + + _ => HashMap::new() } - } + + _ => HashMap::new() } - map } -fn get_values_from_update_mut(assignments: &mut [Assignment]) -> HashMap { - let mut map = HashMap::new(); - for assignment in assignments { - if let Expr::Value(v) = &mut assignment.value { - map.insert(assignment.id.iter().map(|x| &x.value).join("."), v); +/// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. +#[async_trait] +fn encrypt_columns( statement : &mut CassandraStatement, columns : &Vec, key_source : &KeyManager, key_id : &str) -> bool { + + let mut data_changed = false; + match statement { + CassandraStatement::Insert(insert) => { + let indices = insert.columns.iter().enumerate().filter( |(i,col_name)| columns.contains( col_name )) + .map( |(i,col_name)| i).collect(); + match &mut insert.values { + InsertValues::Values(operands) => { + for i in indices { + let operand = operands[i].unwrap(); + let mut protected = Protected::Plaintext( MessageValue::from(operand )); + protected = protected.protect( key_source, key_id ).await?; + std::mem::replace( &mut operands[i], Operand::from(protected)); + data_changed = true; + } + } + InsertValues::Json(_) => { + // TODO parse json and encrypt. + } + } + } + CassandraStatement::Update(update) => { + for assignment in &mut update.assignments { + if columns.contains( &assignment.name.column ) { + let mut protected = Protected::Plaintext( MessageValue::from(&assignment.value) ); + protected = protected.protect( key_source, key_id ).await?; + assignment.value = Operand::from( protected ); + data_changed = true; + } + } } + _ => {} } - map + data_changed } #[async_trait] @@ -201,40 +270,27 @@ impl Transform for Protect { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { // encrypt the values included in any INSERT or UPDATE queries for message in message_wrapper.messages.iter_mut() { - let mut invalidate_cache = false; + let mut data_changed = false; if let Some(namespace) = message.namespace() { if namespace.len() == 2 { - if let Some(Frame::Cassandra(frame)) = message.frame() { - if let Ok(queries) = frame.operation.queries() { - for query in queries { - if let Some((_, tables)) = - self.keyspace_table_columns.get_key_value(&namespace[0]) - { - if let Some((_, columns)) = tables.get_key_value(&namespace[1]) - { - let mut values = - get_values_from_insert_or_update_mut(query); - for col in columns { - if let Some(value) = values.get_mut(col) { - let mut protected = Protected::Plaintext( - MessageValue::from(&**value), - ); - protected = protected - .protect(&self.key_source, &self.key_id) - .await?; - **value = - SQLValue::from(&MessageValue::from(protected)); - invalidate_cache = true; - } - } - } + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) = message.frame() + { + if let Some((_, tables)) = + self.keyspace_table_columns.get_key_value(&namespace[0]) + { + if let Some((_, columns)) = tables.get_key_value(&namespace[1]) { + for mut stmt in query.statement { + data_changed = encrypt_columns(&mut stmt, columns, &self.key_source, &self.key_id ) } } } } } } - if invalidate_cache { + if data_changed { message.invalidate_cache(); } } @@ -254,34 +310,78 @@ impl Transform for Protect { })) = response.frame() { if let Some(namespace) = request.namespace() { - if let Some(Frame::Cassandra(frame)) = request.frame() { - if let Ok(queries) = frame.operation.queries() { - for query in queries { - let projection: Vec = - get_values_from_insert_or_update_mut(query) - .into_keys() - .collect(); - if namespace.len() == 2 { - if let Some((_keyspace, tables)) = - self.keyspace_table_columns.get_key_value(&namespace[0]) - { - if let Some((_table, protect_columns)) = - tables.get_key_value(&namespace[1]) - { - let mut positions: Vec = Vec::new(); - for (i, p) in projection.iter().enumerate() { - if protect_columns.contains(p) { - positions.push(i); - } - } - for row in rows.iter_mut() { - for index in &mut positions { - if let Some(v) = row.get_mut(*index) { + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) = request.frame() + { + if namespace.len() == 2 { + if let Some((_keyspace, tables)) = + self.keyspace_table_columns.get_key_value(&namespace[0]) + { + if let Some((_table, protect_columns)) = + tables.get_key_value(&namespace[1]) + { + for cassandra_statement in query.statement { + if let CassandraStatement::Select(select) = cassandra_statement { + let positions : Vec = select.columns.iter().enumerate() + .filter_map( | (i,col)| { + if let SelectElement::Column(named) = col { + if protect_columns.contains(&named.name) + { + Some(i) + } else { + None + } + } else { + None + } + }).collect(); + for row in rows { + for index in positions { + if let Some(v) = row.get_mut(index) { if let MessageValue::Bytes(_) = v { let protected = - Protected::from_encrypted_bytes_value( - v, - ) + Protected::from_encrypted_bytes_value(v) + .await?; + let new_value: MessageValue = protected + .unprotect(&self.key_source, &self.key_id) + .await?; + *v = new_value; + invalidate_cache = true; + } else { + warn!("Tried decrypting non-blob column") + } + } + } + } + } + } + } + } + } + let projection: Vec = get_values_from_insert_or_update_mut(query) + .into_keys() + .collect(); + if namespace.len() == 2 { + if let Some((_keyspace, tables)) = + self.keyspace_table_columns.get_key_value(&namespace[0]) + { + if let Some((_table, protect_columns)) = + tables.get_key_value(&namespace[1]) + { + let mut positions: Vec = Vec::new(); + for (i, p) in projection.iter().enumerate() { + if protect_columns.contains(p) { + positions.push(i); + } + } + for row in rows { + for index in &mut positions { + if let Some(v) = row.get_mut(*index) { + if let MessageValue::Bytes(_) = v { + let protected = + Protected::from_encrypted_bytes_value(v) .await?; let new_value: MessageValue = protected .unprotect( @@ -317,3 +417,4 @@ impl Transform for Protect { Ok(result) } } + diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 9dab2f975..4f3aa851a 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,6 +1,6 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, CQL, Frame, RedisFrame}; use crate::message::{Message, MessageValue, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ @@ -15,6 +15,9 @@ use serde::Deserialize; use sqlparser::ast::{Assignment, BinaryOperator, Expr, Ident, Query, SetExpr, Statement, Value}; use std::borrow::Borrow; use std::collections::HashMap; +use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::common::{Operand, PrimaryKey, RelationElement, RelationOperator, WhereClause}; +use tracing::info; const TRUE: [u8; 1] = [0x1]; const FALSE: [u8; 1] = [0x0]; @@ -31,6 +34,13 @@ pub struct TableCacheSchema { range_key: Vec, } + +impl From<&PrimaryKey> for TableCacheSchema { + fn from(value: &PrimaryKey) -> TableCacheSchema { + TableCacheSchema { partition_key: value.partition.clone(), range_key: value.clustering.clone() } + } +} + impl RedisConfig { pub async fn get_transform(&self, topics: &TopicHolder) -> Result { Ok(Transforms::RedisCache(SimpleRedisCache { @@ -52,10 +62,7 @@ impl SimpleRedisCache { "SimpleRedisCache" } - async fn get_or_update_from_cache( - &mut self, - mut messages_cass_request: Messages, - ) -> ChainResponse { + async fn get_or_update_from_cache(&mut self, mut messages: Messages) -> ChainResponse { // This function is a little hard to follow, so heres an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. // 1. messages_cass_request: @@ -73,22 +80,28 @@ impl SimpleRedisCache { // - we can get away with this because batches can only contain INSERT/UPDATE/DELETE and therefore always contain either an ERROR or a VOID RESULT // + if the request is a CassandraOperation::Query then we consume a single message from messages_redis_response converting it to a cassandra response // * These are the cassandra responses that we return from the function. - - let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); - for cass_request in &mut messages_cass_request { - if let Some(table_name) = cass_request.namespace().map(|x| x.join(".")) { - match cass_request.frame() { - Some(Frame::Cassandra(frame)) => { - for query in frame.operation.queries()? { - let table_cache_schema = self - .caching_schema - .get(&table_name) - .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; - - messages_redis_request.push(Message::from_frame(Frame::Redis( - build_redis_ast_from_sql(query, table_cache_schema)?, - ))); - } + let mut stream_ids = Vec::with_capacity(messages.len()); + for message in &mut messages { + if let Some(Frame::Cassandra(frame)) = message.frame() { + stream_ids.push(frame.stream_id); + } else { + bail!("Failed to parse cassandra message"); + } + if let Some(table_name) = message.namespace().map(|x| x.join(".")) { + *message = match message.frame() { + Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) => { + let table_cache_schema = self + .caching_schema + .get(&table_name) + .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; + + Message::from_frame(Frame::Redis(build_redis_ast_from_cql3( + query, + table_cache_schema, + )?)) } message => bail!("cannot fetch {message:?} from cache"), } @@ -131,29 +144,6 @@ impl SimpleRedisCache { } } -// TODO: We don't need to do it this way and allocate another struct -struct ValueHelper(Value); - -impl ValueHelper { - fn as_bytes(&self) -> &[u8] { - match &self.0 { - Value::Number(v, false) => v.as_bytes(), - Value::SingleQuotedString(v) => v.as_bytes(), - Value::NationalStringLiteral(v) => v.as_bytes(), - Value::HexStringLiteral(v) => v.as_bytes(), - Value::Boolean(v) => { - if *v { - &TRUE - } else { - &FALSE - } - } - Value::Null => &[], - _ => unreachable!(), - } - } -} - fn append_prefix_min(min: &mut Vec) { if min.is_empty() { min.push(b'['); @@ -170,286 +160,219 @@ fn append_prefix_max(max: &mut Vec) { } } -fn build_zrangebylex_min_max_from_sql( - expr: &Expr, - pks: &[String], +fn build_zrangebylex_min_max_from_cql3( + operator : &RelationOperator, + operand: &Operand, min: &mut Vec, max: &mut Vec, ) -> Result<()> { - match expr { - Expr::BinaryOp { left, op, right } => { - // first check if this is a related to PK - if let Expr::Identifier(i) = left.borrow() { - if pks.iter().any(|v| *v == i.value) { - //Ignore this as we build the pk constraint elsewhere - return Ok(()); - } - } - - match op { - BinaryOperator::Gt => { - // we shift the value for Gt so that it works with other GtEq operators - if let Expr::Value(v) = right.borrow() { - let vh = ValueHelper(v.clone()); - - let mut minrv = Vec::from(vh.as_bytes()); - let last_byte = minrv.last_mut().unwrap(); - *last_byte += 1; - - append_prefix_min(min); - min.extend(minrv.iter()); - } - } - BinaryOperator::Lt => { - // we shift the value for Lt so that it works with other LtEq operators - if let Expr::Value(v) = right.borrow() { - let vh = ValueHelper(v.clone()); - - let mut maxrv = Vec::from(vh.as_bytes()); - let last_byte = maxrv.last_mut().unwrap(); - *last_byte -= 1; - - append_prefix_max(max); - max.extend(maxrv.iter()); - } - } - BinaryOperator::GtEq => { - if let Expr::Value(v) = right.borrow() { - let vh = ValueHelper(v.clone()); - - let minrv = Vec::from(vh.as_bytes()); - append_prefix_min(min); - min.extend(minrv.iter()); - } - } - BinaryOperator::LtEq => { - if let Expr::Value(v) = right.borrow() { - let vh = ValueHelper(v.clone()); - - let maxrv = Vec::from(vh.as_bytes()); - - append_prefix_max(max); - max.extend(maxrv.iter()); - } - } - BinaryOperator::Eq => { - if let Expr::Value(v) = right.borrow() { - let vh = ValueHelper(v.clone()); - - let vh_bytes = vh.as_bytes(); - - append_prefix_min(min); - append_prefix_max(max); - min.extend(vh_bytes.iter()); - max.extend(vh_bytes.iter()); - } - } - BinaryOperator::And => { - build_zrangebylex_min_max_from_sql(left, pks, min, max)?; - build_zrangebylex_min_max_from_sql(right, pks, min, max)?; - } - _ => { - return Err(anyhow!("Couldn't build query")); - } + let mut bytes = + Vec::from( match operand { + Operand::Const(value) => { + match value.to_uppercase().as_str() { + "TRUE" => &TRUE, + "FALSE" => &FALSE, + _ => value.as_bytes(), } } - _ => { + Operand::Map(_) | + Operand::Set(_) | + Operand::List(_) | + Operand::Tuple(_) | + Operand::Column(_) | + Operand::Func(_) => operand.to_string().as_bytes(), + Operand::Null => &[], + }); + + match operator { + RelationOperator::LessThan => { + let last_byte = bytes.last_mut().unwrap(); + *last_byte -= 1; + + append_prefix_max(max); + max.extend(bytes.iter()); + } + RelationOperator::LessThanOrEqual => { + append_prefix_max(max); + max.extend(bytes.iter()); + } + + RelationOperator::Equal => { + append_prefix_min(min); + append_prefix_max(max); + min.extend(bytes.iter()); + max.extend(bytes.iter()); + } + RelationOperator::GreaterThanOrEqual => { + append_prefix_min(min); + min.extend(bytes.iter()); + } + RelationOperator::GreaterThan => { + let last_byte = bytes.last_mut().unwrap(); + *last_byte += 1; + append_prefix_min(min); + min.extend(bytes.iter()); + } + // should "IN"" be converted to an "or" "eq" combination + + RelationOperator::NotEqual | + RelationOperator::In | + RelationOperator::Contains | + RelationOperator::ContainsKey | + RelationOperator::IsNot => { return Err(anyhow!("Couldn't build query")); } } Ok(()) } -fn build_redis_ast_from_sql( - ast: &Statement, +fn build_redis_frames_from_where_clause( where_clause : &Vec, table_cache_schema: &TableCacheSchema) -> Vec { + let mut min: Vec = Vec::new(); + let mut max: Vec = Vec::new(); + + where_clause.iter().filter_map(|relation_element| + { + match &relation_element.obj { + Operand::Column(name) => { + if table_cache_schema.partition_key.contains(name) { + Some((&relation_element.oper,&relation_element.value)) + } else { None } + }, + _ => None + } + }).for_each(|(operator,values)| { + for operand in values { + build_zrangebylex_min_max_from_cql3( operator,operand, &mut min, &mut max, ); + } + }); + + let min = if min.is_empty() { + Bytes::from_static(b"-") + } else { + Bytes::from(min) + }; + let max = if max.is_empty() { + Bytes::from_static(b"+") + } else { + Bytes::from(max) + }; + + let where_columns = WhereClause::get_column_relation_element_map( where_clause ); + let pk = table_cache_schema + .partition_key + .iter() + .filter_map(|k| { + let x = where_columns.get(k); + if x.is_none() { + return None + } + let y = x.iter().filter(|x| x.oper == RelationOperator::Equal).nth(0); + if y.is_none() { + return None + } + Some(y.value) + }) + .fold(BytesMut::new(), |mut acc, v| { + if let Some(v) = v { + v.iter().for_each(|vv| acc.extend(MessageValue::from( vv ).into_str_bytes())); + } + acc + }); + vec![ + RedisFrame::BulkString("ZRANGEBYLEX".into()), + RedisFrame::BulkString(pk.freeze()), + RedisFrame::BulkString(min), + RedisFrame::BulkString(max), + ] +} +fn build_redis_ast_from_cql3 ( + ast: &CQL, table_cache_schema: &TableCacheSchema, -) -> Result { - match ast { - Statement::Query(q) => match &q.body { - SetExpr::Select(s) if s.selection.is_some() => { - let expr = s.selection.as_ref().unwrap(); - let mut min: Vec = Vec::new(); - let mut max: Vec = Vec::new(); - - build_zrangebylex_min_max_from_sql( - expr, - &table_cache_schema.partition_key, - &mut min, - &mut max, - )?; - - let min = if min.is_empty() { - Bytes::from_static(b"-") - } else { - Bytes::from(min) - }; - let max = if max.is_empty() { - Bytes::from_static(b"+") +) -> Result +{ + match &ast.statement[0] { + CassandraStatement::Select(select) => { + if select.where_clause.is_some() { + Ok(RedisFrame::Array( build_redis_frames_from_where_clause( &select.where_clause.unwrap(),table_cache_schema))) } else { - Bytes::from(max) - }; - + Err(anyhow!("Cant build query from statement: {}", &ast.statement[0])) + } + } + CassandraStatement::Insert(insert) => { + let value_map : HashMap = insert.get_value_map(); let pk = table_cache_schema .partition_key .iter() - .map(|k| get_equal_value_from_expr(expr, k)) + .map(|k| value_map.get(k.as_str()).unwrap()) .fold(BytesMut::new(), |mut acc, v| { - if let Some(v) = v { - acc.extend(MessageValue::from(v).into_str_bytes()); - } + acc.extend(MessageValue::from(*v).into_str_bytes()); acc }); - - let commands_buffer = vec![ - RedisFrame::BulkString("ZRANGEBYLEX".into()), + let mut redis_frames: Vec = vec![ + RedisFrame::BulkString("ZADD".into()), RedisFrame::BulkString(pk.freeze()), - RedisFrame::BulkString(min), - RedisFrame::BulkString(max), ]; - Ok(RedisFrame::Array(commands_buffer)) + + let mut map = HashMap::new(); + value_map.iter().for_each(|(key,value)| {map.insert( key, MessageValue::from( value ));}); + Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames))) } - expr => Err(anyhow!("Can't build query from expr: {}", expr)), - }, - Statement::Insert { - source, columns, .. - } => { - let query_values = get_values_from_insert(columns, source); - - let pk = table_cache_schema - .partition_key - .iter() - .map(|k| query_values.get(k.as_str()).unwrap()) - .fold(BytesMut::new(), |mut acc, v| { - acc.extend(MessageValue::from(*v).into_str_bytes()); - acc - }); - - insert_or_update(table_cache_schema, query_values, pk) - } - Statement::Update { - assignments, - selection, - .. - } => { - let query_values = get_values_from_update(assignments); - - let pk = table_cache_schema - .partition_key - .iter() - .map(|k| get_equal_value_from_expr(selection.as_ref().unwrap(), k).unwrap()) - .fold(BytesMut::new(), |mut acc, v| { - acc.extend(MessageValue::from(v).into_str_bytes()); - acc - }); - - insert_or_update(table_cache_schema, query_values, pk) + CassandraStatement::Update(update) => { + let mut redis_frames = build_redis_frames_from_where_clause( &update.where_clause, table_cache_schema); + let mut map = HashMap::new(); + for x in update.assignments { + // skip any columns with +/- modifiers. + if x.operator.is_none() { + map.insert(x.name.to_string(), MessageValue::from(x.value)) + } + } + + Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames ))) + + } + + statement => Err(anyhow!("Cant build query from statement: {}", statement)), } - statement => Err(anyhow!("Cant build query from statement: {}", statement)), - } } -fn insert_or_update( +fn add_values_to_redis_frames( table_cache_schema: &TableCacheSchema, - query_values: HashMap, - pk: BytesMut, -) -> Result { - let mut commands_buffer: Vec = vec![ - RedisFrame::BulkString("ZADD".into()), - RedisFrame::BulkString(pk.freeze()), - ]; + query_values: HashMap, + mut redis_frames : Vec +) -> Vec { let clustering = table_cache_schema .range_key .iter() .map(|k| query_values.get(k.as_str()).unwrap()) .fold(BytesMut::new(), |mut acc, v| { - acc.extend(MessageValue::from(*v).into_str_bytes()); + acc.extend(v.into_str_bytes()); acc }); - let values = query_values + query_values .iter() .filter_map(|(p, v)| { - if table_cache_schema.partition_key.iter().all(|x| x != p) - && table_cache_schema.range_key.iter().all(|x| x != p) + if !(table_cache_schema.partition_key.contains(p) || + table_cache_schema.range_key.contains( p)) { - Some(MessageValue::from(*v)) - } else { None + } else { + Some(v) } }) - .collect_vec(); - - for v in values { - commands_buffer.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); - let mut value = clustering.clone(); - if !value.is_empty() { - value.put_u8(b':'); - } - value.extend(v.clone().into_str_bytes()); - commands_buffer.push(RedisFrame::BulkString(value.freeze())); - } - - Ok(RedisFrame::Array(commands_buffer)) -} - -fn get_values_from_insert<'a>( - columns: &'a [Ident], - source: &'a Query, -) -> HashMap { - let mut map = HashMap::new(); - let mut columns_iter = columns.iter(); - if let SetExpr::Values(v) = &source.body { - for value in &v.0 { - for ex in value { - if let Expr::Value(v) = ex { - if let Some(c) = columns_iter.next() { - // TODO: We should be able to avoid allocation here - map.insert(c.value.to_string(), v); - } - } + .for_each( |v| { + redis_frames.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); + let mut value = clustering.clone(); + if !value.is_empty() { + value.put_u8(b':'); } - } - } - map -} - -fn get_values_from_update(assignments: &[Assignment]) -> HashMap { - let mut map = HashMap::new(); - for assignment in assignments { - if let Expr::Value(v) = &assignment.value { - map.insert(assignment.id.iter().map(|x| &x.value).join("."), v); - } - } - map -} + value.extend(v.into_str_bytes()); + redis_frames.push(RedisFrame::BulkString(value.freeze())); + }); -fn get_equal_value_from_expr<'a>(expr: &'a Expr, find_identifier: &str) -> Option<&'a Value> { - if let Expr::BinaryOp { left, op, right } = expr { - match op { - BinaryOperator::And => get_equal_value_from_expr(left, find_identifier) - .or_else(|| get_equal_value_from_expr(right, find_identifier)), - BinaryOperator::Eq => { - if let Expr::Identifier(i) = left.borrow() { - if i.value == find_identifier { - if let Expr::Value(v) = right.borrow() { - Some(v) - } else { - None - } - } else { - None - } - } else { - None - } - } - _ => None, - } - } else { - None - } + redis_frames } #[async_trait] From 4aaa6e48558e102e4dd46aa033510742162cf26d Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 30 Mar 2022 12:46:27 +0100 Subject: [PATCH 02/60] partial fixes --- Cargo.lock | 11 +- shotover-proxy/src/codec/cassandra.rs | 23 +-- shotover-proxy/src/frame/cassandra.rs | 50 +++---- shotover-proxy/src/message/mod.rs | 96 +++++++------ shotover-proxy/src/transforms/protect/mod.rs | 135 +++--------------- .../src/transforms/query_counter.rs | 15 +- shotover-proxy/src/transforms/redis/cache.rs | 103 ++++++------- 7 files changed, 164 insertions(+), 269 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 376123415..6a7c57443 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,12 +438,17 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#464bf95b8dfbe5ee7d4ab5d91af07871aca17507" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#bf4f943c300a382e428eedbd42db9f497145cf55" dependencies = [ + "bigdecimal", + "bytes", + "hex", "itertools", + "num", "regex", "tree-sitter", "tree-sitter-cql", + "uuid", ] [[package]] @@ -1124,9 +1129,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282a6247722caba404c065016bbfa522806e51714c34f5dfc3e4a3a46fcb4223" +checksum = "0f647032dfaa1f8b6dc29bd3edb7bbef4861b8b8007ebb118d6db284fd59f6ee" dependencies = [ "autocfg", "hashbrown 0.11.2", diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index 063a82b0f..485114d1c 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -143,12 +143,6 @@ mod cassandra_protocol_tests { use cassandra_protocol::frame::Version; use cassandra_protocol::query::QueryParams; use hex_literal::hex; - use sqlparser::ast::Expr::BinaryOp; - use sqlparser::ast::Value::SingleQuotedString; - use sqlparser::ast::{ - BinaryOperator, Expr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, Statement, - TableFactor, TableWithJoins, Value as SQLValue, Values, - }; use tokio_util::codec::{Decoder, Encoder}; fn test_frame_codec_roundtrip( @@ -371,7 +365,12 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::Parsed(vec![Statement::Query(Box::new(Query { + query: CQL{ statement: vec![ + CassandraAst::parse( "Select * from system where key = 'local'") + ], has_error: vec![false] }, + /* + + CQL::Parsed(vec![Statement::Query(Box::new(Query { with: None, body: SetExpr::Select(Box::new(Select { distinct: false, @@ -417,7 +416,8 @@ mod cassandra_protocol_tests { offset: None, fetch: None, lock: None, - }))]), + params: Default::default() + }))]),*/ params: QueryParams::default(), }, }))]; @@ -438,7 +438,10 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::Parsed(vec![Statement::Insert { + query: CQL{ statement: vec![ + CassandraAst::parse( "Select bar from foo") + ], has_error: vec![false] }, + /*CQL::Parsed(vec![Statement::Insert { or: None, table_name: ObjectName(vec![ Ident { @@ -470,7 +473,7 @@ mod cassandra_protocol_tests { after_columns: (vec![]), table: false, on: None, - }]), + }]),*/ params: QueryParams::default(), }, }))]; diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 1dcc89123..3c616688f 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,6 +1,6 @@ use std::net::IpAddr; use std::str::FromStr; -use anyhow::anyhow; +use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; use cassandra_protocol::consistency::Consistency; @@ -19,7 +19,6 @@ use cassandra_protocol::frame::{ }; use cassandra_protocol::query::{QueryParams, QueryValues}; use cassandra_protocol::types::{AsCassandraType, CBytes, CBytesShort, CInt, CLong}; -use itertools::Itertools; use nonzero_ext::nonzero; use std::convert::TryInto; use std::num::NonZeroU32; @@ -29,6 +28,7 @@ use cassandra_protocol::types::cassandra_type::CassandraType; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; +use itertools::Itertools; use sodiumoxide::hex; use uuid::Uuid; @@ -80,10 +80,8 @@ pub(crate) fn cell_count(bytes: &[u8]) -> Result { _ => nonzero!(1u32), }) } -======= // TODO remove this and use actual default from session. const DEFAULT_KEYSPACE : &str = ""; ->>>>>>> partial fix #[derive(PartialEq, Debug, Clone)] pub struct CassandraFrame { @@ -124,7 +122,7 @@ impl CassandraFrame { Opcode::Query => { if let RequestBody::Query(body) = frame.request_body()? { CassandraOperation::Query { - query: CQL::parse_from_string(body.query), + query: CQL::parse_from_string(&body.query), params: body.query_params, } } else { @@ -212,7 +210,7 @@ impl CassandraFrame { .map(|query| BatchStatement { ty: match query.subject { BatchQuerySubj::QueryString(query) => { - BatchStatementType::Statement(CQL::parse_from_string(query)) + BatchStatementType::Statement(CQL::parse_from_string(&query)) } BatchQuerySubj::PreparedId(id) => { BatchStatementType::PreparedId(id) @@ -347,16 +345,12 @@ pub enum CassandraOperation { impl CassandraOperation { /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch /// An Err is returned if the operation cannot contain queries or the queries failed to parse. - pub fn queries(&mut self) -> Result> { + pub fn queries(&mut self) -> Result> { match self { CassandraOperation::Query { - query: CQL::Parsed(query), - .. - } => Ok(query.iter_mut()), - CassandraOperation::Query { - query: CQL::FailedToParse(_), + query: cql, .. - } => Err(anyhow!("Couldnt parse query")), + } => Ok(cql.statement.iter_mut()), // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol _ => Err(anyhow!("This operation cannot contain queries")), } @@ -489,19 +483,19 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQL { pub statement : Vec, - pub has_error : Vec + pub has_error : bool, } impl CQL { pub fn to_query_string(&self) -> String { - self.statement.get(0).unwrap().to_string() + self.statement.iter().join( ";" ) } - pub fn parse_from_string(cql: String) -> Self { - let ast = CassandraAST::new(cql ); + pub fn parse_from_string(cql_query_str: &str) -> Self { + let ast = CassandraAST::new(cql_query_str ); CQL { - statement : vec![ast.statement], - has_error : vec![ast.has_error()], + statement : ast.statements, + has_error : ast.has_error(), } } } @@ -538,39 +532,39 @@ impl ToCassandraType for Operand { let ipaddr = IpAddr::from_str(value); if ipaddr.is_ok() { Some(CassandraType::Inet(ipaddr.unwrap())) + } else { + None } - None } } } } } - fn as_cassandra_type(&self) -> Option { match self { Operand::Const(value) => { self.from_string_value( value ) } Operand::Map(values) => { - CassandraType::Map(values.iter().map( |(key,value)| (self.from_string_value( key), self.from_string_value(value)) ).collect()) + Some(CassandraType::Map(values.iter().map( |(key,value)| (self.from_string_value( key).unwrap(), self.from_string_value(value).unwrap()) ).collect())) } Operand::Set( values) => { - CassandraType::Set(values.iter().map( |value| self.from_string_value(value) ).collect()) + Some(CassandraType::Set(values.iter().filter_map( |value| self.from_string_value(value) ).collect())) } Operand::List( values) => { - CassandraType::List(values.iter().map( |value| self.from_string_value(value) ).collect()) + Some(CassandraType::List(values.iter().filter_map( |value| self.from_string_value(value) ).collect())) } Operand::Tuple( values) => { - CassandraType::Tuple( values.iter().map( |o| o.as_cassandra_type()).collect()) + Some(CassandraType::Tuple( values.iter().filter_map( |value| value.as_cassandra_type()).collect())) } Operand::Column(value) => { - CassandraType::Ascii( value.to_string() ) + Some(CassandraType::Ascii( value.to_string() )) } Operand::Func(value) => { - CassandraType::Ascii( value.to_string() ) + Some(CassandraType::Ascii( value.to_string() )) } - Operand::Null => CassandraType::Null, + Operand::Null => Some(CassandraType::Null), } } } diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 83a768c66..dcd67bcd3 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -1,6 +1,8 @@ use crate::codec::redis::redis_query_type; +use crate::frame::{ cassandra, - cassandra::{CassandraMetadata, CassandraOperation, ToCassandraType}; + cassandra::{CassandraMetadata, CassandraOperation, ToCassandraType}, +}; use crate::frame::{CassandraFrame, Frame, MessageType, RedisFrame}; use anyhow::{anyhow, Result}; use bigdecimal::BigDecimal; @@ -26,11 +28,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; use std::num::NonZeroU32; use uuid::Uuid; -use std::str::FromStr; use cql3_parser::common::{DataTypeName, Operand}; -use cql3_parser::common::DataTypeName::Ascii; -use sodiumoxide::hex; - enum Metadata { Cassandra(CassandraMetadata), @@ -455,52 +453,64 @@ pub enum IntSize { I8, // Tinyint } -impl From<&MessageValue> for SQLValue { - fn from(v: &MessageValue) -> Self { - match v { - MessageValue::NULL => SQLValue::Null, - MessageValue::Bytes(b) => { - SQLValue::SingleQuotedString(String::from_utf8(b.to_vec()).unwrap()) - } // TODO: this is definitely wrong - MessageValue::Strings(s) => SQLValue::SingleQuotedString(s.clone()), - MessageValue::Integer(i, _) => SQLValue::Number(i.to_string(), false), - MessageValue::Float(f) => SQLValue::Number(f.to_string(), false), - MessageValue::Boolean(b) => SQLValue::Boolean(*b), - _ => SQLValue::Null, - } - } -} - -impl From<&SQLValue> for MessageValue { - fn from(v: &SQLValue) -> Self { - match v { - SQLValue::Number(v, false) - | SQLValue::SingleQuotedString(v) - | SQLValue::NationalStringLiteral(v) => MessageValue::Strings(v.to_string()), - SQLValue::HexStringLiteral(v) => MessageValue::Strings(v.to_string()), - SQLValue::Boolean(v) => MessageValue::Boolean(*v), - _ => MessageValue::Strings("NULL".to_string()), - } - } -} - impl From<&MessageValue> for Operand { fn from(v: &MessageValue) -> Self { match v { MessageValue::NULL => Operand::Null, - MessageValue::Bytes(b) => Operand::Const( format!("0X{}", b.encode_hex() )), - MessageValue::Strings(s) => Operand::Const(format!("'{}'", s)), - MessageValue::Integer(i, _) => Operand::Const(i.to_string()), - MessageValue::Float(f) => Operand::Const(f.to_string()), - MessageValue::Boolean(b) => Operand::Const( if b { "TRUE".to_string() } else {"FALSE".to_string()}), + MessageValue::Bytes(b) => Operand::from(b) , + MessageValue::Ascii( s ) | + MessageValue::Varchar(s) | + MessageValue::Strings(s) => Operand::from(s.as_str()), + MessageValue::Integer(i, _) => Operand::from(i), + MessageValue::Float(f) => Operand::from(&f.0), + MessageValue::Boolean(b) => Operand::from(b), + MessageValue::Double(d) => Operand::from(&d.0), + MessageValue::Inet(i) => Operand::from( i ), + MessageValue::Varint( i ) => Operand::from(i), + MessageValue::Decimal(d ) => Operand::from(d), + MessageValue::Date( d) => Operand::from(d), + MessageValue::Time(t) | + MessageValue::Counter(t) | + MessageValue::Timestamp(t) => Operand::from(t), + MessageValue::Uuid(u) | + MessageValue::Timeuuid( u) => Operand::from( u ), + + MessageValue::List(l) => {Operand::List( l.iter().map( |x| Operand::from(x).to_string()).collect())} + + MessageValue::Rows(r) => { + Operand::Tuple( r.iter().map( |row| row.iter().map( |m| Operand::from(m)).collect()).map( |v| Operand::Tuple(v)).collect()) + } + + MessageValue::NamedRows(r) => { + Operand::Tuple( r.iter().map( |nr| Operand::Map(nr.iter().map( |(k,v)| (k.clone(), Operand::from(v).to_string())).collect())).collect()) + } + + MessageValue::Set(s) => { + Operand::Set( s.iter().map( |m| Operand::from(m).to_string()).collect()) + } + MessageValue::Map(m) => { + Operand::Map( m.iter().map( |(k,v)| (Operand::from(k).to_string(), Operand::from(v).to_string())).collect() ) + } + + MessageValue::FragmentedResponse(t) | + MessageValue::Tuple(t) => { + Operand::Tuple( t.iter().map( |m| Operand::from(m)).collect()) + } + + MessageValue::Udt(d) | + MessageValue::Document(d) => { + Operand::Map( d.iter().map( |(k,v)| (k.clone(),Operand::from(v).to_string())).collect()) + } + + MessageValue::None => { + Operand::Null + } - _ => {} } } } impl From<&Operand> for MessageValue { - fn from(operand: &Operand) -> Self { operand.as_cassandra_type().map_or( MessageValue::None, |x| MessageValue::create_element( x )) } @@ -543,8 +553,8 @@ impl From<&MessageValue> for DataTypeName { MessageValue::Counter(_) => DataTypeName::Counter, MessageValue::Tuple(_) => DataTypeName::Tuple, MessageValue::Udt(_) => DataTypeName::Tuple, - MessageValue::NULL => {}, - None => {}, + MessageValue::NULL => DataTypeName::Custom( "NULL".to_string()), + MessageValue::None => DataTypeName::Custom( "None".to_string()), } } diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 7bdd1c44d..67eb0eb52 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,5 +1,5 @@ use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, CQL, Frame}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::MessageValue; use crate::transforms::protect::key_management::{KeyManager, KeyManagerConfig}; use crate::transforms::{Transform, Transforms, Wrapper}; @@ -7,19 +7,15 @@ use anyhow::anyhow; use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; -use itertools::Itertools; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::secretbox; use sodiumoxide::crypto::secretbox::{Key, Nonce}; -//use sqlparser::ast::{Assignment, Expr, Ident, Query, SetExpr, Statement, Value as SQLValue}; -use std::borrow::BorrowMut; use std::collections::HashMap; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use cql3_parser::insert::InsertValues; use cql3_parser::select::SelectElement; -use cql3_parser::update::AssignmentElement; -use serde_yaml::seed::from_slice_seed; +use futures::TryFutureExt; use sodiumoxide::hex; use tracing::warn; @@ -98,8 +94,8 @@ impl From for MessageValue { } } -impl From for Operand { - fn from(p: Protected) -> Self { +impl From<&Protected> for Operand { + fn from(p: &Protected) -> Self { match p { Protected::Plaintext(_) => panic!( "tried to move unencrypted value to plaintext without explicitly calling decrypt" @@ -178,73 +174,23 @@ impl ProtectConfig { } } -pub fn get_values_from_insert_or_update_mut(ast: &mut CQL) -> HashMap { - match ast.statement[0] { - Some(stmt) => - match stmt { - CassandraStatement::Insert(insert) => { - match insert.values { - InsertValues::Values(values) => { - let mut result = HashMap::new(); - // if the lengths don't match we will return an empty hashmap. - if values.len() == insert.columns.len() { - for (i, value) in values.iter().enumerate() { - if let Operand::Const(val) = value { - result.insert( insert.columns[i].to_string(), value); - } - } - } else { - // TODO do we need to clear data here? - } - result - } - // TODO parse JSON? - InsertValues::Json(_) => HashMap::new() - } - } - - CassandraStatement::Update(update) => { - let mut result = HashMap::new(); - for assignment in update.assignments { - // the operator adds something like +x or -x to the assignment so it indicates this is not a value - // and thus we should skip it. - if assignment.operator.is_none() { - if let Operand::Const(val) = &assignment.value { - result.insert( assignment.name.to_string(), assignment.value); - } - } else { - // TODO do we need to clear data here? - } - } - result - } - - _ => HashMap::new() - } - - _ => HashMap::new() - } -} - /// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. -#[async_trait] -fn encrypt_columns( statement : &mut CassandraStatement, columns : &Vec, key_source : &KeyManager, key_id : &str) -> bool { +async fn encrypt_columns( statement : &mut CassandraStatement, columns : &Vec, key_source : &KeyManager, key_id : &str) -> Result { let mut data_changed = false; match statement { CassandraStatement::Insert(insert) => { - let indices = insert.columns.iter().enumerate().filter( |(i,col_name)| columns.contains( col_name )) + let indices :Vec= insert.columns.iter().enumerate().filter( |(i,col_name)| columns.contains( col_name )) .map( |(i,col_name)| i).collect(); match &mut insert.values { - InsertValues::Values(operands) => { - for i in indices { - let operand = operands[i].unwrap(); - let mut protected = Protected::Plaintext( MessageValue::from(operand )); - protected = protected.protect( key_source, key_id ).await?; - std::mem::replace( &mut operands[i], Operand::from(protected)); - data_changed = true; + InsertValues::Values(value_operands) => { + for idx in indices { + let mut protected = Protected::Plaintext(MessageValue::from( &value_operands[idx] )); + protected = protected.protect(key_source, key_id).await?; + std::mem::replace( &mut value_operands[idx], Operand::from( &protected)); + data_changed = true } - } + }, InsertValues::Json(_) => { // TODO parse json and encrypt. } @@ -255,14 +201,16 @@ fn encrypt_columns( statement : &mut CassandraStatement, columns : &Vec, if columns.contains( &assignment.name.column ) { let mut protected = Protected::Plaintext( MessageValue::from(&assignment.value) ); protected = protected.protect( key_source, key_id ).await?; - assignment.value = Operand::from( protected ); + assignment.value = Operand::from( &protected ); data_changed = true; } } } - _ => {} + _ => { + // no other statement are modified + } } - data_changed + Ok(data_changed) } #[async_trait] @@ -283,7 +231,7 @@ impl Transform for Protect { { if let Some((_, columns)) = tables.get_key_value(&namespace[1]) { for mut stmt in query.statement { - data_changed = encrypt_columns(&mut stmt, columns, &self.key_source, &self.key_id ) + data_changed = encrypt_columns(&mut stmt, columns, &self.key_source, &self.key_id ).unwrap() } } } @@ -315,6 +263,7 @@ impl Transform for Protect { .. })) = request.frame() { + if namespace.len() == 2 { if let Some((_keyspace, tables)) = self.keyspace_table_columns.get_key_value(&namespace[0]) @@ -360,50 +309,6 @@ impl Transform for Protect { } } } - let projection: Vec = get_values_from_insert_or_update_mut(query) - .into_keys() - .collect(); - if namespace.len() == 2 { - if let Some((_keyspace, tables)) = - self.keyspace_table_columns.get_key_value(&namespace[0]) - { - if let Some((_table, protect_columns)) = - tables.get_key_value(&namespace[1]) - { - let mut positions: Vec = Vec::new(); - for (i, p) in projection.iter().enumerate() { - if protect_columns.contains(p) { - positions.push(i); - } - } - for row in rows { - for index in &mut positions { - if let Some(v) = row.get_mut(*index) { - if let MessageValue::Bytes(_) = v { - let protected = - Protected::from_encrypted_bytes_value(v) - .await?; - let new_value: MessageValue = protected - .unprotect( - &self.key_source, - &self.key_id, - ) - .await?; - *v = new_value; - invalidate_cache = true; - } else { - warn!( - "Tried decrypting non-blob column" - ) - } - } - } - } - } - } - } - } - } } } } diff --git a/shotover-proxy/src/transforms/query_counter.rs b/shotover-proxy/src/transforms/query_counter.rs index 589eb71ee..ae6cd2034 100644 --- a/shotover-proxy/src/transforms/query_counter.rs +++ b/shotover-proxy/src/transforms/query_counter.rs @@ -6,7 +6,7 @@ use anyhow::Result; use async_trait::async_trait; use metrics::{counter, register_counter}; use serde::Deserialize; -use sqlparser::ast::Statement; + #[derive(Debug, Clone)] pub struct QueryCounter { @@ -34,18 +34,7 @@ impl Transform for QueryCounter { Some(Frame::Cassandra(frame)) => match frame.operation.queries() { Ok(queries) => { for statement in queries { - let query_type = match statement { - Statement::Query(_) => "SELECT", - Statement::Insert { .. } => "INSERT", - Statement::Copy { .. } => "COPY", - Statement::Update { .. } => "UPDATE", - Statement::Delete { .. } => "DELETE", - Statement::CreateTable { .. } => "CREATE TABLE", - Statement::AlterTable { .. } => "ALTER TABLE", - Statement::Drop { .. } => "DROP", - _ => "UNRECOGNISED CQL", - }; - counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => query_type, "type" => "cassandra"); + counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => statement.short_name(), "type" => "cassandra"); } } Err(_) => { diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 4f3aa851a..73d6ec3bf 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,6 +1,6 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, CQL, Frame, RedisFrame}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame}; use crate::message::{Message, MessageValue, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ @@ -10,14 +10,10 @@ use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use cassandra_protocol::frame::Version; -use itertools::Itertools; use serde::Deserialize; -use sqlparser::ast::{Assignment, BinaryOperator, Expr, Ident, Query, SetExpr, Statement, Value}; -use std::borrow::Borrow; use std::collections::HashMap; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, PrimaryKey, RelationElement, RelationOperator, WhereClause}; -use tracing::info; const TRUE: [u8; 1] = [0x1]; const FALSE: [u8; 1] = [0x0]; @@ -62,7 +58,10 @@ impl SimpleRedisCache { "SimpleRedisCache" } - async fn get_or_update_from_cache(&mut self, mut messages: Messages) -> ChainResponse { + async fn get_or_update_from_cache( + &mut self, + mut messages_cass_request: Messages + ) -> ChainResponse { // This function is a little hard to follow, so heres an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. // 1. messages_cass_request: @@ -80,28 +79,22 @@ impl SimpleRedisCache { // - we can get away with this because batches can only contain INSERT/UPDATE/DELETE and therefore always contain either an ERROR or a VOID RESULT // + if the request is a CassandraOperation::Query then we consume a single message from messages_redis_response converting it to a cassandra response // * These are the cassandra responses that we return from the function. - let mut stream_ids = Vec::with_capacity(messages.len()); - for message in &mut messages { - if let Some(Frame::Cassandra(frame)) = message.frame() { - stream_ids.push(frame.stream_id); - } else { - bail!("Failed to parse cassandra message"); - } - if let Some(table_name) = message.namespace().map(|x| x.join(".")) { - *message = match message.frame() { - Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) => { - let table_cache_schema = self - .caching_schema - .get(&table_name) - .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; - - Message::from_frame(Frame::Redis(build_redis_ast_from_cql3( - query, - table_cache_schema, - )?)) + + let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); + for cass_request in &mut messages_cass_request { + if let Some(table_name) = cass_request.namespace().map(|x| x.join(".")) { + match cass_request.frame() { + Some(Frame::Cassandra(frame)) => { + for query in frame.operation.queries()? { + let table_cache_schema = self + .caching_schema + .get(&table_name) + .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; + + messages_redis_request.push(Message::from_frame(Frame::Redis( + build_redis_ast_from_cql3(query, table_cache_schema)?, + ))); + } } message => bail!("cannot fetch {message:?} from cache"), } @@ -143,7 +136,6 @@ impl SimpleRedisCache { Ok(messages_cass_request) } } - fn append_prefix_min(min: &mut Vec) { if min.is_empty() { min.push(b'['); @@ -267,11 +259,11 @@ fn build_redis_frames_from_where_clause( where_clause : &Vec, t if x.is_none() { return None } - let y = x.iter().filter(|x| x.oper == RelationOperator::Equal).nth(0); + let y = x.unwrap().iter().filter(|x| x.oper == RelationOperator::Equal).nth(0); if y.is_none() { return None } - Some(y.value) + Some(&y.unwrap().value) }) .fold(BytesMut::new(), |mut acc, v| { if let Some(v) = v { @@ -287,16 +279,16 @@ fn build_redis_frames_from_where_clause( where_clause : &Vec, t ] } fn build_redis_ast_from_cql3 ( - ast: &CQL, + statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, ) -> Result { - match &ast.statement[0] { + match statement { CassandraStatement::Select(select) => { if select.where_clause.is_some() { Ok(RedisFrame::Array( build_redis_frames_from_where_clause( &select.where_clause.unwrap(),table_cache_schema))) } else { - Err(anyhow!("Cant build query from statement: {}", &ast.statement[0])) + Err(anyhow!("Cant build query from statement: {}", statement)) } } CassandraStatement::Insert(insert) => { @@ -315,7 +307,7 @@ fn build_redis_ast_from_cql3 ( ]; let mut map = HashMap::new(); - value_map.iter().for_each(|(key,value)| {map.insert( key, MessageValue::from( value ));}); + value_map.iter().for_each(|(key,value)| {map.insert( key.clone(), MessageValue::from( *value ));}); Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames))) } CassandraStatement::Update(update) => { @@ -324,14 +316,13 @@ fn build_redis_ast_from_cql3 ( for x in update.assignments { // skip any columns with +/- modifiers. if x.operator.is_none() { - map.insert(x.name.to_string(), MessageValue::from(x.value)) + map.insert(x.name.to_string(), MessageValue::from(&x.value)) } } Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames ))) } - statement => Err(anyhow!("Cant build query from statement: {}", statement)), } } @@ -437,19 +428,17 @@ mod test { use crate::transforms::debug::printer::DebugPrinter; use crate::transforms::null::Null; use crate::transforms::redis::cache::{ - build_redis_ast_from_sql, SimpleRedisCache, TableCacheSchema, + build_redis_ast_from_cql3, SimpleRedisCache, TableCacheSchema, }; use crate::transforms::{Transform, Transforms}; use bytes::Bytes; - use sqlparser::ast::Statement; - use sqlparser::dialect::GenericDialect; - use sqlparser::parser::Parser; use std::collections::HashMap; + use cql3_parser::cassandra_ast::CassandraAST; + use cql3_parser::cassandra_statement::CassandraStatement; - fn build_query(query_string: &str) -> Statement { - Parser::parse_sql(&GenericDialect {}, query_string) - .unwrap() - .remove(0) + fn build_query(query_string: &str) -> CassandraStatement { + let ast = CassandraAST::new( query_string ); + ast.statement } #[test] @@ -461,7 +450,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -482,7 +471,7 @@ mod test { let ast = build_query("INSERT INTO foo (z, v) VALUES (1, 123)"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -502,7 +491,7 @@ mod test { }; let ast = build_query("INSERT INTO foo (z, c, v) VALUES (1, 'yo' , 123)"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -523,7 +512,7 @@ mod test { let ast = build_query("UPDATE foo SET c = 'yo', v = 123 WHERE z = 1"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -544,11 +533,11 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query_one = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query_one = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let ast = build_query("SELECT * FROM foo WHERE y = 965 AND z = 1 AND x = 123"); - let query_two = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query_two = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); // Semantically databases treat the order of AND clauses differently, Cassandra however requires clustering key predicates be in order // So here we will just expect the order is correct in the query. TODO: we may need to revisit this as support for other databases is added @@ -564,7 +553,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x > 123 AND x < 999"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -585,7 +574,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123 AND x <= 999"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -606,7 +595,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -627,7 +616,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND y = 2"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -648,7 +637,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -661,7 +650,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x <= 123"); - let query = build_redis_ast_from_sql(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), From 6440910ce19ebb9c72acee8ec315d993bf99ec9f Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 31 Mar 2022 07:07:08 +0100 Subject: [PATCH 03/60] first code complete --- Cargo.lock | 45 ++++++++-------- shotover-proxy/src/codec/cassandra.rs | 9 ++-- shotover-proxy/src/frame/cassandra.rs | 4 +- .../src/transforms/cassandra/peers_rewrite.rs | 2 +- shotover-proxy/src/transforms/protect/mod.rs | 20 +++---- shotover-proxy/src/transforms/redis/cache.rs | 53 ++++++++++--------- 6 files changed, 68 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6a7c57443..ecf86aa0d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#bf4f943c300a382e428eedbd42db9f497145cf55" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#01b16122220dbeb7c1148e26fe44115507b0a9e6" dependencies = [ "bigdecimal", "bytes", @@ -1243,10 +1243,11 @@ checksum = "7fb9b38af92608140b86b693604b9ffcc5824240a484d1ecd4795bacb2fe88f3" [[package]] name = "lock_api" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" +checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" dependencies = [ + "autocfg", "scopeguard", ] @@ -1731,7 +1732,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" dependencies = [ "lock_api", - "parking_lot_core 0.9.1", + "parking_lot_core 0.9.2", ] [[package]] @@ -1750,9 +1751,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954" +checksum = "995f667a6c822200b0433ac218e05582f0e2efa1b922a3fd2fbaadc5f87bab37" dependencies = [ "cfg-if", "libc", @@ -2069,18 +2070,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae183fc1b06c149f0c1793e1eb447c8b04bfe46d48e9e48bfb8d2d7ed64ecf0" +checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" dependencies = [ "bitflags", ] [[package]] name = "redox_users" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7776223e2696f1aa4c6b0170e83212f47296a00424305117d013dfe86fb0fe55" +checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ "getrandom", "redox_syscall", @@ -3268,9 +3269,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-sys" -version = "0.32.0" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6" +checksum = "5acdd78cb4ba54c0045ac14f62d8f94a03d10047904ae2a40afa1e99d8f70825" dependencies = [ "windows_aarch64_msvc", "windows_i686_gnu", @@ -3281,33 +3282,33 @@ dependencies = [ [[package]] name = "windows_aarch64_msvc" -version = "0.32.0" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5" +checksum = "17cffbe740121affb56fad0fc0e421804adf0ae00891205213b5cecd30db881d" [[package]] name = "windows_i686_gnu" -version = "0.32.0" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615" +checksum = "2564fde759adb79129d9b4f54be42b32c89970c18ebf93124ca8870a498688ed" [[package]] name = "windows_i686_msvc" -version = "0.32.0" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172" +checksum = "9cd9d32ba70453522332c14d38814bceeb747d80b3958676007acadd7e166956" [[package]] name = "windows_x86_64_gnu" -version = "0.32.0" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc" +checksum = "cfce6deae227ee8d356d19effc141a509cc503dfd1f850622ec4b0f84428e1f4" [[package]] name = "windows_x86_64_msvc" -version = "0.32.0" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316" +checksum = "d19538ccc21819d01deaf88d6a17eae6596a12e9aafdbb97916fb49896d89de9" [[package]] name = "winreg" diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index 485114d1c..b634b9602 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -359,15 +359,14 @@ mod cassandra_protocol_tests { "0400000307000000350000002e53454c454354202a2046524f4d20737973 74656d2e6c6f63616c205748455245206b6579203d20276c6f63616c27000100" ); + let messages = vec![Message::from_frame(Frame::Cassandra(CassandraFrame { version: Version::V4, stream_id: 3, tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL{ statement: vec![ - CassandraAst::parse( "Select * from system where key = 'local'") - ], has_error: vec![false] }, + query: CQL::parse_from_string("Select * from system where key = 'local'"), /* CQL::Parsed(vec![Statement::Query(Box::new(Query { @@ -438,9 +437,7 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL{ statement: vec![ - CassandraAst::parse( "Select bar from foo") - ], has_error: vec![false] }, + query: CQL::parse_from_string("Select bar from foo"), /*CQL::Parsed(vec![Statement::Insert { or: None, table_name: ObjectName(vec![ diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 3c616688f..d6eab97e9 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -18,7 +18,7 @@ use cassandra_protocol::frame::{ Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version, }; use cassandra_protocol::query::{QueryParams, QueryValues}; -use cassandra_protocol::types::{AsCassandraType, CBytes, CBytesShort, CInt, CLong}; +use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use nonzero_ext::nonzero; use std::convert::TryInto; use std::num::NonZeroU32; @@ -494,8 +494,8 @@ impl CQL { pub fn parse_from_string(cql_query_str: &str) -> Self { let ast = CassandraAST::new(cql_query_str ); CQL { - statement : ast.statements, has_error : ast.has_error(), + statement : ast.statements, } } } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 4bee6da0d..e5fceaf13 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -106,7 +106,7 @@ mod test { query::QueryParams, }; - fn create_query_message(query: String) -> Message { + fn create_query_message(query: &str) -> Message { let original = Frame::Cassandra(CassandraFrame { version: Version::V4, stream_id: 0, diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 67eb0eb52..1d7083b55 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -15,7 +15,6 @@ use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use cql3_parser::insert::InsertValues; use cql3_parser::select::SelectElement; -use futures::TryFutureExt; use sodiumoxide::hex; use tracing::warn; @@ -180,14 +179,15 @@ async fn encrypt_columns( statement : &mut CassandraStatement, columns : &Vec { - let indices :Vec= insert.columns.iter().enumerate().filter( |(i,col_name)| columns.contains( col_name )) - .map( |(i,col_name)| i).collect(); + let indices :Vec= insert.columns.iter().enumerate() + .filter_map( |(i,col_name)| if columns.contains( col_name ) { Some(i) } else { None }) + .collect(); match &mut insert.values { InsertValues::Values(value_operands) => { for idx in indices { let mut protected = Protected::Plaintext(MessageValue::from( &value_operands[idx] )); protected = protected.protect(key_source, key_id).await?; - std::mem::replace( &mut value_operands[idx], Operand::from( &protected)); + value_operands[idx] = Operand::from( &protected ); data_changed = true } }, @@ -230,8 +230,8 @@ impl Transform for Protect { self.keyspace_table_columns.get_key_value(&namespace[0]) { if let Some((_, columns)) = tables.get_key_value(&namespace[1]) { - for mut stmt in query.statement { - data_changed = encrypt_columns(&mut stmt, columns, &self.key_source, &self.key_id ).unwrap() + for stmt in &mut query.statement { + data_changed = encrypt_columns(stmt, columns, &self.key_source, &self.key_id ).await?; } } } @@ -271,7 +271,7 @@ impl Transform for Protect { if let Some((_table, protect_columns)) = tables.get_key_value(&namespace[1]) { - for cassandra_statement in query.statement { + for cassandra_statement in &query.statement { if let CassandraStatement::Select(select) = cassandra_statement { let positions : Vec = select.columns.iter().enumerate() .filter_map( | (i,col)| { @@ -286,9 +286,9 @@ impl Transform for Protect { None } }).collect(); - for row in rows { - for index in positions { - if let Some(v) = row.get_mut(index) { + for row in &mut *rows { + for index in &positions { + if let Some(v) = row.get_mut(*index) { if let MessageValue::Bytes(_) = v { let protected = Protected::from_encrypted_bytes_value(v) diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 73d6ec3bf..da877572f 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -160,22 +160,23 @@ fn build_zrangebylex_min_max_from_cql3( ) -> Result<()> { let mut bytes = - Vec::from( match operand { + match operand { Operand::Const(value) => { - match value.to_uppercase().as_str() { + Vec::from( + match value.to_uppercase().as_str() { "TRUE" => &TRUE, "FALSE" => &FALSE, _ => value.as_bytes(), - } + }) } Operand::Map(_) | Operand::Set(_) | Operand::List(_) | Operand::Tuple(_) | Operand::Column(_) | - Operand::Func(_) => operand.to_string().as_bytes(), - Operand::Null => &[], - }); + Operand::Func(_) => Vec::from(operand.to_string().as_bytes()), + Operand::Null => vec!(), + }; match operator { RelationOperator::LessThan => { @@ -219,10 +220,10 @@ fn build_zrangebylex_min_max_from_cql3( Ok(()) } -fn build_redis_frames_from_where_clause( where_clause : &Vec, table_cache_schema: &TableCacheSchema) -> Vec { +fn build_redis_frames_from_where_clause( where_clause : &Vec, table_cache_schema: &TableCacheSchema) -> Result> { let mut min: Vec = Vec::new(); let mut max: Vec = Vec::new(); - + let mut had_err = None; where_clause.iter().filter_map(|relation_element| { match &relation_element.obj { @@ -235,10 +236,16 @@ fn build_redis_frames_from_where_clause( where_clause : &Vec, t } }).for_each(|(operator,values)| { for operand in values { - build_zrangebylex_min_max_from_cql3( operator,operand, &mut min, &mut max, ); + let x = build_zrangebylex_min_max_from_cql3( operator,operand, &mut min, &mut max, ); + if x.is_err() { + had_err = x.err() + } } }); + if had_err.is_some() { + return Err(had_err.unwrap()); + } let min = if min.is_empty() { Bytes::from_static(b"-") } else { @@ -266,17 +273,15 @@ fn build_redis_frames_from_where_clause( where_clause : &Vec, t Some(&y.unwrap().value) }) .fold(BytesMut::new(), |mut acc, v| { - if let Some(v) = v { - v.iter().for_each(|vv| acc.extend(MessageValue::from( vv ).into_str_bytes())); - } + v.iter().for_each(|operand| acc.extend(MessageValue::from( operand ).into_str_bytes())); acc }); - vec![ + Ok(vec![ RedisFrame::BulkString("ZRANGEBYLEX".into()), RedisFrame::BulkString(pk.freeze()), RedisFrame::BulkString(min), RedisFrame::BulkString(max), - ] + ]) } fn build_redis_ast_from_cql3 ( statement: &CassandraStatement, @@ -286,7 +291,7 @@ fn build_redis_ast_from_cql3 ( match statement { CassandraStatement::Select(select) => { if select.where_clause.is_some() { - Ok(RedisFrame::Array( build_redis_frames_from_where_clause( &select.where_clause.unwrap(),table_cache_schema))) + Ok(RedisFrame::Array( build_redis_frames_from_where_clause( &select.where_clause.as_ref().unwrap(),table_cache_schema)?)) } else { Err(anyhow!("Cant build query from statement: {}", statement)) } @@ -301,7 +306,7 @@ fn build_redis_ast_from_cql3 ( acc.extend(MessageValue::from(*v).into_str_bytes()); acc }); - let mut redis_frames: Vec = vec![ + let redis_frames: Vec = vec![ RedisFrame::BulkString("ZADD".into()), RedisFrame::BulkString(pk.freeze()), ]; @@ -311,12 +316,12 @@ fn build_redis_ast_from_cql3 ( Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames))) } CassandraStatement::Update(update) => { - let mut redis_frames = build_redis_frames_from_where_clause( &update.where_clause, table_cache_schema); + let redis_frames = build_redis_frames_from_where_clause( &update.where_clause, table_cache_schema)?; let mut map = HashMap::new(); - for x in update.assignments { + for x in &update.assignments { // skip any columns with +/- modifiers. if x.operator.is_none() { - map.insert(x.name.to_string(), MessageValue::from(&x.value)) + map.insert(x.name.to_string(), MessageValue::from(&x.value)); } } @@ -337,8 +342,8 @@ fn add_values_to_redis_frames( .range_key .iter() .map(|k| query_values.get(k.as_str()).unwrap()) - .fold(BytesMut::new(), |mut acc, v| { - acc.extend(v.into_str_bytes()); + .fold(BytesMut::new(), |mut acc, message_value| { + acc.extend(&message_value.clone().into_str_bytes()); acc }); @@ -353,13 +358,13 @@ fn add_values_to_redis_frames( Some(v) } }) - .for_each( |v| { + .for_each( |message_value| { redis_frames.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); let mut value = clustering.clone(); if !value.is_empty() { value.put_u8(b':'); } - value.extend(v.into_str_bytes()); + value.extend(message_value.clone().into_str_bytes()); redis_frames.push(RedisFrame::BulkString(value.freeze())); }); @@ -438,7 +443,7 @@ mod test { fn build_query(query_string: &str) -> CassandraStatement { let ast = CassandraAST::new( query_string ); - ast.statement + ast.statements[0].clone() } #[test] From 9c45cbbb5525910ebd109eb748cef41129014a05 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 31 Mar 2022 07:49:05 +0100 Subject: [PATCH 04/60] fixed clippy issues --- shotover-proxy/src/frame/cassandra.rs | 155 +++++++++--------- shotover-proxy/src/message/mod.rs | 6 +- .../src/transforms/cassandra/peers_rewrite.rs | 6 +- shotover-proxy/src/transforms/protect/mod.rs | 2 +- shotover-proxy/src/transforms/redis/cache.rs | 18 +- 5 files changed, 89 insertions(+), 98 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index d6eab97e9..d58730741 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,5 +1,3 @@ -use std::net::IpAddr; -use std::str::FromStr; use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; @@ -18,18 +16,20 @@ use cassandra_protocol::frame::{ Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version, }; use cassandra_protocol::query::{QueryParams, QueryValues}; -use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; -use nonzero_ext::nonzero; -use std::convert::TryInto; -use std::num::NonZeroU32; -use std::slice::IterMut; use cassandra_protocol::types::blob::Blob; use cassandra_protocol::types::cassandra_type::CassandraType; +use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use itertools::Itertools; +use nonzero_ext::nonzero; use sodiumoxide::hex; +use std::convert::TryInto; +use std::net::IpAddr; +use std::num::NonZeroU32; +use std::slice::IterMut; +use std::str::FromStr; use uuid::Uuid; use crate::message::{MessageValue, QueryType}; @@ -81,7 +81,7 @@ pub(crate) fn cell_count(bytes: &[u8]) -> Result { }) } // TODO remove this and use actual default from session. -const DEFAULT_KEYSPACE : &str = ""; +const DEFAULT_KEYSPACE: &str = ""; #[derive(PartialEq, Debug, Clone)] pub struct CassandraFrame { @@ -210,7 +210,9 @@ impl CassandraFrame { .map(|query| BatchStatement { ty: match query.subject { BatchQuerySubj::QueryString(query) => { - BatchStatementType::Statement(CQL::parse_from_string(&query)) + BatchStatementType::Statement(CQL::parse_from_string( + &query, + )) } BatchQuerySubj::PreparedId(id) => { BatchStatementType::PreparedId(id) @@ -243,17 +245,14 @@ impl CassandraFrame { pub fn get_query_type(&self) -> QueryType { /* - Read, - Write, - ReadWrite, - SchemaChange, - PubSubMessage, - */ + Read, + Write, + ReadWrite, + SchemaChange, + PubSubMessage, + */ match &self.operation { - CassandraOperation::Query { - query: cql, - .. - } => match cql.statement.get(0).unwrap() { + CassandraOperation::Query { query: cql, .. } => match cql.statement.get(0).unwrap() { CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, CassandraStatement::AlterRole(_) => QueryType::SchemaChange, @@ -288,7 +287,7 @@ impl CassandraFrame { CassandraStatement::ListRoles(_) => QueryType::Read, CassandraStatement::Revoke(_) => QueryType::SchemaChange, CassandraStatement::Select(_) => QueryType::Read, - CassandraStatement::Truncate( _) => QueryType::Write, + CassandraStatement::Truncate(_) => QueryType::Write, CassandraStatement::Update(_) => QueryType::Write, CassandraStatement::Use(_) => QueryType::SchemaChange, CassandraStatement::Unknown(_) => QueryType::Read, @@ -299,10 +298,11 @@ impl CassandraFrame { pub fn namespace(&self) -> Vec { match &self.operation { - CassandraOperation::Query { - query: cql, - .. - } => cql.statement.iter().map( |x|x.get_keyspace( DEFAULT_KEYSPACE )).collect(), + CassandraOperation::Query { query: cql, .. } => cql + .statement + .iter() + .map(|x| x.get_keyspace(DEFAULT_KEYSPACE)) + .collect(), _ => vec![], } } @@ -347,10 +347,7 @@ impl CassandraOperation { /// An Err is returned if the operation cannot contain queries or the queries failed to parse. pub fn queries(&mut self) -> Result> { match self { - CassandraOperation::Query { - query: cql, - .. - } => Ok(cql.statement.iter_mut()), + CassandraOperation::Query { query: cql, .. } => Ok(cql.statement.iter_mut()), // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol _ => Err(anyhow!("This operation cannot contain queries")), } @@ -482,33 +479,33 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQL { - pub statement : Vec, - pub has_error : bool, + pub statement: Vec, + pub has_error: bool, } impl CQL { pub fn to_query_string(&self) -> String { - self.statement.iter().join( ";" ) + self.statement.iter().join(";") } pub fn parse_from_string(cql_query_str: &str) -> Self { - let ast = CassandraAST::new(cql_query_str ); + let ast = CassandraAST::new(cql_query_str); CQL { - has_error : ast.has_error(), - statement : ast.statements, + has_error: ast.has_error(), + statement: ast.statements, } } } pub trait ToCassandraType { - fn from_string_value(&self, value : &str) -> Option; + fn from_string_value(&self, value: &str) -> Option; fn as_cassandra_type(&self) -> Option; } impl ToCassandraType for Operand { - fn from_string_value(&self, value : &str ) -> Option { + fn from_string_value(&self, value: &str) -> Option { // check for string types - if value.starts_with("'") || value.starts_with("$$") { + if value.starts_with('\'') || value.starts_with("$$") { Some(CassandraType::Varchar(value.to_string())) } else if value.starts_with("0X") || value.starts_with("X'") { let mut chars = value.chars(); @@ -516,60 +513,58 @@ impl ToCassandraType for Operand { chars.next(); let bytes = hex::decode(chars.as_str()).unwrap(); Some(CassandraType::Blob(Blob::from(bytes))) + } else if let Ok(n) = i64::from_str(value) { + Some(CassandraType::Bigint(n)) + } else if let Ok(n) = f64::from_str(value) { + Some(CassandraType::Double(n)) + } else if let Ok(uuid) = Uuid::parse_str(value) { + Some(CassandraType::Uuid(uuid)) + } else if let Ok(ipaddr) = IpAddr::from_str(value) { + Some(CassandraType::Inet(ipaddr)) } else { - let num = i64::from_str(value); - if num.is_ok() { - Some(CassandraType::Bigint(num.unwrap())) - } else { - let num = f64::from_str(value); - if num.is_ok() { - Some(CassandraType::Double(num.unwrap())) - } else { - let uuid = Uuid::parse_str(value); - if uuid.is_ok() { - Some(CassandraType::Uuid(uuid.unwrap())) - } else { - let ipaddr = IpAddr::from_str(value); - if ipaddr.is_ok() { - Some(CassandraType::Inet(ipaddr.unwrap())) - } else { - None - } - } - } - } + None } } fn as_cassandra_type(&self) -> Option { match self { - Operand::Const(value) => { - self.from_string_value( value ) - } - Operand::Map(values) => { - Some(CassandraType::Map(values.iter().map( |(key,value)| (self.from_string_value( key).unwrap(), self.from_string_value(value).unwrap()) ).collect())) - } - Operand::Set( values) => { - Some(CassandraType::Set(values.iter().filter_map( |value| self.from_string_value(value) ).collect())) - } - Operand::List( values) => { - Some(CassandraType::List(values.iter().filter_map( |value| self.from_string_value(value) ).collect())) - } - Operand::Tuple( values) => { - Some(CassandraType::Tuple( values.iter().filter_map( |value| value.as_cassandra_type()).collect())) - } - Operand::Column(value) => { - Some(CassandraType::Ascii( value.to_string() )) - } - Operand::Func(value) => { - Some(CassandraType::Ascii( value.to_string() )) - } + Operand::Const(value) => self.from_string_value(value), + Operand::Map(values) => Some(CassandraType::Map( + values + .iter() + .map(|(key, value)| { + ( + self.from_string_value(key).unwrap(), + self.from_string_value(value).unwrap(), + ) + }) + .collect(), + )), + Operand::Set(values) => Some(CassandraType::Set( + values + .iter() + .filter_map(|value| self.from_string_value(value)) + .collect(), + )), + Operand::List(values) => Some(CassandraType::List( + values + .iter() + .filter_map(|value| self.from_string_value(value)) + .collect(), + )), + Operand::Tuple(values) => Some(CassandraType::Tuple( + values + .iter() + .filter_map(|value| value.as_cassandra_type()) + .collect(), + )), + Operand::Column(value) => Some(CassandraType::Ascii(value.to_string())), + Operand::Func(value) => Some(CassandraType::Ascii(value.to_string())), Operand::Null => Some(CassandraType::Null), } } } - #[derive(PartialEq, Debug, Clone)] pub enum CassandraResult { Rows { diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index dcd67bcd3..9fc345380 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -478,7 +478,7 @@ impl From<&MessageValue> for Operand { MessageValue::List(l) => {Operand::List( l.iter().map( |x| Operand::from(x).to_string()).collect())} MessageValue::Rows(r) => { - Operand::Tuple( r.iter().map( |row| row.iter().map( |m| Operand::from(m)).collect()).map( |v| Operand::Tuple(v)).collect()) + Operand::Tuple( r.iter().map( |row| row.iter().map( Operand::from ).collect()).map( Operand::Tuple ).collect()) } MessageValue::NamedRows(r) => { @@ -494,7 +494,7 @@ impl From<&MessageValue> for Operand { MessageValue::FragmentedResponse(t) | MessageValue::Tuple(t) => { - Operand::Tuple( t.iter().map( |m| Operand::from(m)).collect()) + Operand::Tuple( t.iter().map( Operand::from ).collect()) } MessageValue::Udt(d) | @@ -512,7 +512,7 @@ impl From<&MessageValue> for Operand { impl From<&Operand> for MessageValue { fn from(operand: &Operand) -> Self { - operand.as_cassandra_type().map_or( MessageValue::None, |x| MessageValue::create_element( x )) + operand.as_cassandra_type().map_or( MessageValue::None, MessageValue::create_element) } } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index e5fceaf13..1ed61699d 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -163,14 +163,14 @@ mod test { #[test] fn test_is_system_peers_v2() { assert!(is_system_peers(&mut create_query_message( - "SELECT * FROM system.peers_v2;".into() + "SELECT * FROM system.peers_v2;" ))); assert!(!is_system_peers(&mut create_query_message( - "SELECT * FROM not_system.peers_v2;".into() + "SELECT * FROM not_system.peers_v2;" ))); - assert!(!is_system_peers(&mut create_query_message("".into()))); + assert!(!is_system_peers(&mut create_query_message(""))); } #[test] diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 1d7083b55..9a661c406 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -174,7 +174,7 @@ impl ProtectConfig { } /// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. -async fn encrypt_columns( statement : &mut CassandraStatement, columns : &Vec, key_source : &KeyManager, key_id : &str) -> Result { +async fn encrypt_columns( statement : &mut CassandraStatement, columns : &[String], key_source : &KeyManager, key_id : &str) -> Result { let mut data_changed = false; match statement { diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index da877572f..495d07603 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -220,7 +220,7 @@ fn build_zrangebylex_min_max_from_cql3( Ok(()) } -fn build_redis_frames_from_where_clause( where_clause : &Vec, table_cache_schema: &TableCacheSchema) -> Result> { +fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], table_cache_schema: &TableCacheSchema) -> Result> { let mut min: Vec = Vec::new(); let mut max: Vec = Vec::new(); let mut had_err = None; @@ -243,8 +243,8 @@ fn build_redis_frames_from_where_clause( where_clause : &Vec, t } }); - if had_err.is_some() { - return Err(had_err.unwrap()); + if let Some(e) = had_err { + return Err(e); } let min = if min.is_empty() { Bytes::from_static(b"-") @@ -263,13 +263,9 @@ fn build_redis_frames_from_where_clause( where_clause : &Vec, t .iter() .filter_map(|k| { let x = where_columns.get(k); - if x.is_none() { - return None - } - let y = x.unwrap().iter().filter(|x| x.oper == RelationOperator::Equal).nth(0); - if y.is_none() { - return None - } + x?; + let y = x.unwrap().iter().find(|x| x.oper == RelationOperator::Equal); + y?; Some(&y.unwrap().value) }) .fold(BytesMut::new(), |mut acc, v| { @@ -291,7 +287,7 @@ fn build_redis_ast_from_cql3 ( match statement { CassandraStatement::Select(select) => { if select.where_clause.is_some() { - Ok(RedisFrame::Array( build_redis_frames_from_where_clause( &select.where_clause.as_ref().unwrap(),table_cache_schema)?)) + Ok(RedisFrame::Array( build_redis_frames_from_where_clause( select.where_clause.as_ref().unwrap(),table_cache_schema)?)) } else { Err(anyhow!("Cant build query from statement: {}", statement)) } From ed1aeeafd6424245a58432dc6566a0f162116056 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 31 Mar 2022 09:44:21 +0100 Subject: [PATCH 05/60] partial fix for cache processing --- Cargo.lock | 97 -------------------- shotover-proxy/Cargo.toml | 2 +- shotover-proxy/src/transforms/redis/cache.rs | 64 +++++++++---- 3 files changed, 49 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ecf86aa0d..43c37eeb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,21 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - [[package]] name = "ahash" version = "0.7.6" @@ -124,21 +109,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "backtrace" -version = "0.3.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = "base64" version = "0.13.0" @@ -269,25 +239,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663" -[[package]] -name = "cassandra-cpp" -version = "0.17.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd909919df0f560903b92cbcc06642e7b8957bb96ffb50808f9e0b82d3c09f25" -dependencies = [ - "cassandra-cpp-sys", - "error-chain", - "parking_lot 0.12.0", - "slog", - "uuid", -] - -[[package]] -name = "cassandra-cpp-sys" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eadde404d13254d8592f8aaa946a1fb7a2769c137df56ed29a3be5996973a156" - [[package]] name = "cassandra-protocol" version = "1.1.0" @@ -741,16 +692,6 @@ dependencies = [ "libc", ] -[[package]] -name = "error-chain" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" -dependencies = [ - "backtrace", - "version_check", -] - [[package]] name = "fastrand" version = "1.7.0" @@ -919,12 +860,6 @@ dependencies = [ "wasi 0.10.2+wasi-snapshot-preview1", ] -[[package]] -name = "gimli" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" - [[package]] name = "governor" version = "0.4.2" @@ -1387,16 +1322,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "miniz_oxide" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" -dependencies = [ - "adler", - "autocfg", -] - [[package]] name = "mio" version = "0.8.2" @@ -1625,15 +1550,6 @@ dependencies = [ "libc", ] -[[package]] -name = "object" -version = "0.27.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.10.0" @@ -2242,12 +2158,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "rustc-demangle" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" - [[package]] name = "rustc_version" version = "0.4.0" @@ -2477,7 +2387,6 @@ dependencies = [ "bytes", "bytes-utils", "cached", - "cassandra-cpp", "cassandra-protocol", "clap 3.1.6", "cql3_parser", @@ -2567,12 +2476,6 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" -[[package]] -name = "slog" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8347046d4ebd943127157b94d63abb990fcf729dc4e9978927fdf4ac3c998d06" - [[package]] name = "smallvec" version = "1.8.0" diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index c6389f6c9..9ab38d268 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -84,7 +84,7 @@ threadpool = "1.0" tokio-io-timeout = "1.1.1" num_cpus = "1.0" serial_test = "0.6.0" -cassandra-cpp = "0.17.0" +#cassandra-cpp = "0.17.0" test-helpers = { path = "../test-helpers" } hex-literal = "0.3.3" nix = "0.23.0" diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 495d07603..fb19ac6a9 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -224,21 +224,19 @@ fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], tabl let mut min: Vec = Vec::new(); let mut max: Vec = Vec::new(); let mut had_err = None; - where_clause.iter().filter_map(|relation_element| - { - match &relation_element.obj { - Operand::Column(name) => { - if table_cache_schema.partition_key.contains(name) { - Some((&relation_element.oper,&relation_element.value)) - } else { None } - }, - _ => None - } - }).for_each(|(operator,values)| { - for operand in values { - let x = build_zrangebylex_min_max_from_cql3( operator,operand, &mut min, &mut max, ); - if x.is_err() { - had_err = x.err() + + let where_columns = WhereClause::get_column_relation_element_map( where_clause ); + + // process the partition key + where_columns.iter().filter(|(name,_relation_elements)| { + ! table_cache_schema.partition_key.contains( name ) + }).for_each( |(_name,relation_elements)| { + for relation_element in relation_elements { + for operand in &relation_element.value { + let x = build_zrangebylex_min_max_from_cql3(&relation_element.oper, &operand, &mut min, &mut max, ); + if x.is_err() { + had_err = x.err() + } } } }); @@ -257,7 +255,6 @@ fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], tabl Bytes::from(max) }; - let where_columns = WhereClause::get_column_relation_element_map( where_clause ); let pk = table_cache_schema .partition_key .iter() @@ -436,6 +433,7 @@ mod test { use std::collections::HashMap; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; + use tls_parser::nom::AsBytes; fn build_query(query_string: &str) -> CassandraStatement { let ast = CassandraAST::new( query_string ); @@ -460,6 +458,40 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"]123:965")), ]); + if let RedisFrame::Array(v)= query { + assert_eq!(4,v.len()); + let mut iter = v.iter(); + if let RedisFrame::BulkString( b ) = iter.next().unwrap() { + assert_eq!( b"ZRANGEBYLEX", b.as_bytes()); + } + if let RedisFrame::BulkString( b ) = iter.next().unwrap() { + assert_eq!( b"1", b.as_bytes()); + } + + if let RedisFrame::BulkString( b ) = iter.next().unwrap() { + if b.starts_with( b"[123:") { + assert_eq!( b"[123:965", b.as_bytes()); + } else { + assert_eq!( b"[965:123", b.as_bytes()); + } + } else { + assert!(false); + } + + if let RedisFrame::BulkString( b ) = iter.next().unwrap() { + if b.starts_with( b"]123:") { + assert_eq!( b"]123:965", b.as_bytes()); + } else { + assert_eq!( b"]965:123", b.as_bytes()); + } + } else { + assert!(false); + } + + } else { + assert!(false) + } + assert_eq!(expected, query); } From 5443078bfc91fbaf9452ed73601e8c81e49e970d Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 1 Apr 2022 11:10:01 +0100 Subject: [PATCH 06/60] updated tests --- Cargo.lock | 121 +++++++- shotover-proxy/Cargo.toml | 2 +- shotover-proxy/src/codec/cassandra.rs | 4 +- shotover-proxy/src/frame/cassandra.rs | 42 ++- shotover-proxy/src/message/mod.rs | 43 +-- .../src/transforms/cassandra/peers_rewrite.rs | 290 +++++++++++++----- shotover-proxy/src/transforms/protect/mod.rs | 4 +- shotover-proxy/src/transforms/redis/cache.rs | 179 ++++++----- 8 files changed, 448 insertions(+), 237 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 43c37eeb3..0e5d40bcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "ahash" version = "0.7.6" @@ -109,6 +124,21 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "backtrace" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.13.0" @@ -239,6 +269,25 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663" +[[package]] +name = "cassandra-cpp" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd909919df0f560903b92cbcc06642e7b8957bb96ffb50808f9e0b82d3c09f25" +dependencies = [ + "cassandra-cpp-sys", + "error-chain", + "parking_lot 0.12.0", + "slog", + "uuid", +] + +[[package]] +name = "cassandra-cpp-sys" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadde404d13254d8592f8aaa946a1fb7a2769c137df56ed29a3be5996973a156" + [[package]] name = "cassandra-protocol" version = "1.1.0" @@ -307,9 +356,9 @@ dependencies = [ [[package]] name = "clap" -version = "3.1.6" +version = "3.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8c93436c21e4698bacadf42917db28b23017027a4deccb35dbe47a7e7840123" +checksum = "c67e7973e74896f4bba06ca2dcfd28d54f9cb8c035e940a32b88ed48f5f5ecf2" dependencies = [ "atty", "bitflags", @@ -324,9 +373,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "3.1.4" +version = "3.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95d038ede1a964ce99f49cbe27a7fb538d1da595e4b4f70b8c8f338d17bf16" +checksum = "a3aab4734e083b809aaf5794e14e756d1c798d2c69c7f7de7a09a2f5214993c1" dependencies = [ "heck", "proc-macro-error", @@ -389,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#01b16122220dbeb7c1148e26fe44115507b0a9e6" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#6d7f96a8d7d8fee3c80d9a14f3a2060a33d006b3" dependencies = [ "bigdecimal", "bytes", @@ -692,6 +741,16 @@ dependencies = [ "libc", ] +[[package]] +name = "error-chain" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "backtrace", + "version_check", +] + [[package]] name = "fastrand" version = "1.7.0" @@ -860,6 +919,12 @@ dependencies = [ "wasi 0.10.2+wasi-snapshot-preview1", ] +[[package]] +name = "gimli" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" + [[package]] name = "governor" version = "0.4.2" @@ -878,9 +943,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62eeb471aa3e3c9197aa4bfeabfe02982f6dc96f750486c0bb0009ac58b26d2b" +checksum = "37a82c6d637fc9515a4694bbf1cb2457b79d81ce52b3108bdeea58b07dd34a57" dependencies = [ "bytes", "fnv", @@ -891,7 +956,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util 0.6.9", + "tokio-util 0.7.1", "tracing", ] @@ -1322,6 +1387,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +dependencies = [ + "adler", + "autocfg", +] + [[package]] name = "mio" version = "0.8.2" @@ -1550,6 +1625,15 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.10.0" @@ -1750,9 +1834,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe" +checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" [[package]] name = "pktparse" @@ -2158,6 +2242,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "rustc-demangle" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" + [[package]] name = "rustc_version" version = "0.4.0" @@ -2387,8 +2477,9 @@ dependencies = [ "bytes", "bytes-utils", "cached", + "cassandra-cpp", "cassandra-protocol", - "clap 3.1.6", + "clap 3.1.7", "cql3_parser", "crc16", "criterion", @@ -2476,6 +2567,12 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" +[[package]] +name = "slog" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8347046d4ebd943127157b94d63abb990fcf729dc4e9978927fdf4ac3c998d06" + [[package]] name = "smallvec" version = "1.8.0" @@ -2929,7 +3026,7 @@ dependencies = [ [[package]] name = "tree-sitter-cql" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/tree-sitter-cql?branch=main#495dffa341b8342312abd895a2bc4b3d316db23a" +source = "git+https://github.com/Claude-at-Instaclustr/tree-sitter-cql?branch=main#c01b7c8682b68cb8fcec00adc667744fdc8db2e2" dependencies = [ "cc", "regex", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 9ab38d268..c6389f6c9 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -84,7 +84,7 @@ threadpool = "1.0" tokio-io-timeout = "1.1.1" num_cpus = "1.0" serial_test = "0.6.0" -#cassandra-cpp = "0.17.0" +cassandra-cpp = "0.17.0" test-helpers = { path = "../test-helpers" } hex-literal = "0.3.3" nix = "0.23.0" diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index b634b9602..434e0c602 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -366,7 +366,7 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::parse_from_string("Select * from system where key = 'local'"), + query: CQL::parse_from_string("Select * from system.local where key = 'local'"), /* CQL::Parsed(vec![Statement::Query(Box::new(Query { @@ -437,7 +437,7 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::parse_from_string("Select bar from foo"), + query: CQL::parse_from_string("insert into system.foo (bar) values ('bar2')"), /*CQL::Parsed(vec![Statement::Insert { or: None, table_name: ObjectName(vec![ diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index d58730741..32763e2a9 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; @@ -80,8 +81,6 @@ pub(crate) fn cell_count(bytes: &[u8]) -> Result { _ => nonzero!(1u32), }) } -// TODO remove this and use actual default from session. -const DEFAULT_KEYSPACE: &str = ""; #[derive(PartialEq, Debug, Clone)] pub struct CassandraFrame { @@ -296,15 +295,36 @@ impl CassandraFrame { } } - pub fn namespace(&self) -> Vec { - match &self.operation { - CassandraOperation::Query { query: cql, .. } => cql - .statement - .iter() - .map(|x| x.get_keyspace(DEFAULT_KEYSPACE)) - .collect(), - _ => vec![], - } + /// returns a mapping of table names to (index,statement) pairs, where index is the index in the CQL + /// of the statement. + pub fn get_table_name_statement_map(&self) -> HashMap> { + let mut result: HashMap> = HashMap::new(); + if let CassandraOperation::Query { query: cql, .. } = &self.operation { + + cql.statement.iter().enumerate().for_each( |(idx,statement)| { + let name = match statement { + CassandraStatement::AlterTable(t) => {Some(&t.name)} + CassandraStatement::CreateIndex(i) => {Some(&i.table)} + CassandraStatement::CreateMaterializedView(m) => {Some(&m.table)} + CassandraStatement::CreateTable(t) => {Some(&t.name)} + CassandraStatement::DropTable(t) => {Some(&t.name)} + CassandraStatement::DropTrigger(t) => {Some(&t.table)} + CassandraStatement::Insert(i) => {Some(&i.table_name)} + CassandraStatement::Select(s) => {Some(&s.table_name)} + CassandraStatement::Truncate(t) => {Some(t)} + CassandraStatement::Update(u) => {Some(&u.table_name)} + _ => None + }; + if let Some(k) = name { + if let Some(v) = result.get_mut(k) { + v.push( (idx,statement)); + } else { + result.insert( k.to_string(), vec![(idx,statement)]); + } + } + }); + }; + result } pub fn encode(self) -> RawCassandraFrame { diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 9fc345380..022d84c88 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -174,10 +174,14 @@ impl Message { } } - /// Returns None when fails to parse the message - pub fn namespace(&mut self) -> Option> { + + /// Returns the table names found in the message. + /// None if the statements do not contain table names. + pub fn get_table_names(&mut self) -> Option> { match self.frame()? { - Frame::Cassandra(cassandra) => Some(cassandra.namespace()), + Frame::Cassandra(cassandra) => { + Some(cassandra.get_table_name_statement_map().iter().map( |(k,_v)| k.clone()).collect()) + } Frame::Redis(_) => unimplemented!(), Frame::None => Some(vec![]), } @@ -711,39 +715,6 @@ impl MessageValue { CassandraType::Null => MessageValue::NULL, } } - - pub fn into_str_bytes(self) -> Bytes { - match self { - MessageValue::NULL => Bytes::from("".to_string()), - MessageValue::None => Bytes::from("".to_string()), - MessageValue::Bytes(b) => b, - MessageValue::Strings(s) => Bytes::from(s), - MessageValue::Integer(i, _) => Bytes::from(format!("{i}")), - MessageValue::Float(f) => Bytes::from(format!("{f}")), - MessageValue::Boolean(b) => Bytes::from(format!("{b}")), - MessageValue::Inet(i) => Bytes::from(format!("{i}")), - MessageValue::FragmentedResponse(_) => unimplemented!(), - MessageValue::Document(_) => unimplemented!(), - MessageValue::NamedRows(_) => unimplemented!(), - MessageValue::List(_) => unimplemented!(), - MessageValue::Rows(_) => unimplemented!(), - MessageValue::Ascii(_) => unimplemented!(), - MessageValue::Double(_) => unimplemented!(), - MessageValue::Set(_) => unimplemented!(), - MessageValue::Map(_) => unimplemented!(), - MessageValue::Varint(_) => unimplemented!(), - MessageValue::Decimal(_) => unimplemented!(), - MessageValue::Date(_) => unimplemented!(), - MessageValue::Timestamp(_) => unimplemented!(), - MessageValue::Timeuuid(_) => unimplemented!(), - MessageValue::Varchar(_) => unimplemented!(), - MessageValue::Uuid(_) => unimplemented!(), - MessageValue::Time(_) => unimplemented!(), - MessageValue::Counter(_) => unimplemented!(), - MessageValue::Tuple(_) => unimplemented!(), - MessageValue::Udt(_) => unimplemented!(), - } - } } impl From for cassandra_protocol::types::value::Bytes { diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 1ed61699d..68a7f9e41 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -6,7 +6,10 @@ use crate::{ }; use anyhow::Result; use async_trait::async_trait; +use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::select::SelectElement; use serde::Deserialize; +use std::collections::HashMap; #[derive(Deserialize, Debug, Clone)] pub struct CassandraPeersRewriteConfig { @@ -30,54 +33,89 @@ pub struct CassandraPeersRewrite { impl Transform for CassandraPeersRewrite { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { // Find the indices of queries to system.peers & system.peers_v2 - let system_peers = message_wrapper + // we need to know which columns in which CQL queries in which messages have system peers + let mut column_names: HashMap> = HashMap::new(); + + message_wrapper .messages .iter_mut() .enumerate() - .filter_map(|(i, m)| if is_system_peers(m) { Some(i) } else { None }) - .collect::>(); + .filter_map(|(i, m)| { + let sys_peers = extract_native_port_column(m); + if sys_peers.is_empty() { + None + } else { + Some((i, sys_peers)) + } + }) + .for_each(|(k, mut v)| { + if let Some(x) = column_names.get_mut(&k) { + x.append(&mut v); + } else { + column_names.insert(k, v); + } + }); let mut response = message_wrapper.call_next_transform().await?; - for i in system_peers { - rewrite_port(&mut response[i], self.port); + for (idx, name_list) in column_names { + rewrite_port(&mut response[idx], &name_list, self.port); } Ok(response) } } -fn is_system_peers(message: &mut Message) -> bool { - if let Some(Frame::Cassandra(_)) = message.frame() { - if let Some(namespace) = message.namespace() { - if namespace.len() > 1 { - return namespace[0] == "system" && namespace[1] == "peers_v2"; +/// determine if the message contains a SELECT from `system.peers_v2` that includes the `native_port` column +/// return a list of (statement index, column index) pairs +fn extract_native_port_column(message: &mut Message) -> Vec { + let mut result: Vec = vec![]; + if let Some(Frame::Cassandra(cassandra)) = message.frame() { + let map = cassandra.get_table_name_statement_map(); + let statements = map.get("system.peers_v2"); + if let Some(v) = statements { + for (_stmt_idx, x) in v { + if let CassandraStatement::Select(select) = x { + select + .columns + .iter() + .for_each(|select_element| match select_element { + SelectElement::Column(col_name) => { + if col_name.name.eq("native_port") { + result.push(col_name.alias_or_name()); + } + } + SelectElement::Star => result.push("native_port".to_string()), + _ => {} + }); + } } } } - - false + result } /// Rewrite the `native_port` field in the results from a query to `system.peers_v2` table /// Only Cassandra queries to the `system.peers` table found via the `is_system_peers` function should be passed to this -fn rewrite_port(message: &mut Message, new_port: u32) { +fn rewrite_port(message: &mut Message, column_names: &[String], new_port: u32) { if let Some(Frame::Cassandra(frame)) = message.frame() { - if let CassandraOperation::Result(CassandraResult::Rows { value, metadata }) = - &mut frame.operation + if let CassandraOperation::Result(CassandraResult::Rows { value, metadata }) = &mut frame.operation { - let port_column_index = metadata + let port_column_index : Vec= metadata .col_specs - .iter() - .position(|col| col.name.as_str() == "native_port"); + .iter().enumerate() + .filter_map(|(idx, col)| if column_names.contains(&col.name) { + Some(idx) + } else { None } + ).collect(); - if let Some(i) = port_column_index { - if let MessageValue::Rows(rows) = &mut *value { - for row in rows.iter_mut() { - row[i] = MessageValue::Integer(new_port as i64, IntSize::I32); + if let MessageValue::Rows(rows) = value { + for row in rows.iter_mut() { + for idx in &port_column_index { + row[*idx] = MessageValue::Integer(new_port as i64, IntSize::I32); } - message.invalidate_cache(); } + message.invalidate_cache(); } } else { panic!( @@ -93,6 +131,7 @@ mod test { use super::*; use crate::frame::{CassandraFrame, CQL}; use crate::transforms::cassandra::peers_rewrite::CassandraResult::Rows; + use cassandra_protocol::frame::frame_result::ColType; use cassandra_protocol::{ consistency::Consistency, frame::{ @@ -129,7 +168,7 @@ mod test { Message::from_frame(original) } - fn create_response_message(rows: Vec>) -> Message { + fn create_response_message(col_specs: &[ColSpec], rows: Vec>) -> Message { let original = Frame::Cassandra(CassandraFrame { version: Version::V4, stream_id: 0, @@ -145,14 +184,7 @@ mod test { ks_name: "system".into(), table_name: "peers_v2".into(), }), - col_specs: vec![ColSpec { - table_spec: None, - name: "native_port".into(), - col_type: ColTypeOption { - id: Int, - value: None, - }, - }], + col_specs: col_specs.to_owned(), }, }), }); @@ -162,74 +194,166 @@ mod test { #[test] fn test_is_system_peers_v2() { - assert!(is_system_peers(&mut create_query_message( - "SELECT * FROM system.peers_v2;" - ))); + let v = + extract_native_port_column(&mut create_query_message("SELECT * FROM system.peers_v2;")); + assert_eq!(1, v.len()); + assert_eq!("native_port", v[0]); - assert!(!is_system_peers(&mut create_query_message( - "SELECT * FROM not_system.peers_v2;" - ))); + let v = extract_native_port_column(&mut create_query_message( + "SELECT * FROM not_system.peers_v2;", + )); + assert!(v.is_empty()); - assert!(!is_system_peers(&mut create_query_message(""))); + let v = extract_native_port_column(&mut create_query_message( + "SELECT native_port as foo from system.peers_v2", + )); + assert_eq!(1, v.len()); + assert_eq!("foo", v[0]); + + let v = extract_native_port_column(&mut create_query_message( + "SELECT native_port as foo, native_port from system.peers_v2", + )); + assert_eq!(2, v.len()); + assert_eq!("foo", v[0]); + assert_eq!("native_port", v[1]); } #[test] - fn test_rewrite_port() { + fn test_simple_rewrite_port() { //Test rewrites `native_port` column when included - { - let mut message = create_response_message(vec![ + + let col_spec = vec![ColSpec { + table_spec: None, + name: "native_port".into(), + col_type: ColTypeOption { + id: Int, + value: None, + }, + }]; + let mut message = create_response_message( + &col_spec, + vec![ vec![MessageValue::Integer(9042, IntSize::I32)], vec![MessageValue::Integer(9042, IntSize::I32)], - ]); + ], + ); - rewrite_port(&mut message, 9043); + rewrite_port(&mut message, &["native_port".to_string()], 9043); - let expected = create_response_message(vec![ + let expected = create_response_message( + &col_spec, + vec![ vec![MessageValue::Integer(9043, IntSize::I32)], vec![MessageValue::Integer(9043, IntSize::I32)], - ]); + ], + ); - assert_eq!(message, expected); - } + assert_eq!(message, expected); + } - // Test does not rewrite anything when `native_port` column not included - { - let frame = Frame::Cassandra(CassandraFrame { - version: Version::V4, - stream_id: 0, - tracing_id: None, - warnings: vec![], - operation: CassandraOperation::Result(Rows { - value: MessageValue::Rows(vec![vec![MessageValue::Inet( - "127.0.0.1".parse().unwrap(), - )]]), - metadata: RowsMetadata { - flags: RowsMetadataFlags::GLOBAL_TABLE_SPACE, - columns_count: 1, - paging_state: None, - global_table_spec: Some(TableSpec { - ks_name: "system".into(), - table_name: "peers_v2".into(), - }), - col_specs: vec![ColSpec { - table_spec: None, - name: "peer".into(), - col_type: ColTypeOption { - id: Inet, - value: None, - }, - }], - }, - }), - }); + #[test] + fn test_simple_rewrite_port_no_match() { + let col_spec = vec![ColSpec { + table_spec: None, + name: "peer".into(), + col_type: ColTypeOption { + id: Inet, + value: None, + }, + }]; - let mut original = Message::from_frame(frame); + let mut original = create_response_message( + &col_spec, + vec![ + vec![MessageValue::Inet("127.0.0.1".parse().unwrap())], + vec![MessageValue::Inet("10.123.56.1".parse().unwrap())], + ], + ); - let expected = original.clone(); + let expected = create_response_message( + &col_spec, + vec![ + vec![MessageValue::Inet("127.0.0.1".parse().unwrap())], + vec![MessageValue::Inet("10.123.56.1".parse().unwrap())], + ], + ); - rewrite_port(&mut original, 9043); + rewrite_port( + &mut original, + &["native_port".to_string(), "alias_port".to_string()], + 9043, + ); - assert_eq!(original, expected); - } + assert_eq!(original, expected); + } + + #[test] + fn test_alias_rewrite_port() { + let col_spec = vec![ + ColSpec { + table_spec: None, + name: "native_port".into(), + col_type: ColTypeOption { + id: Int, + value: None, + }, + }, + ColSpec { + table_spec: None, + name: "some_text".into(), + col_type: ColTypeOption { + id: ColType::Varchar, + value: None, + }, + }, + ColSpec { + table_spec: None, + name: "alias_port".into(), + col_type: ColTypeOption { + id: Int, + value: None, + }, + }, + ]; + + let mut original = create_response_message( + &col_spec, + vec![ + vec![ + MessageValue::Integer(9042, IntSize::I32), + MessageValue::Strings("Hello".into()), + MessageValue::Integer(9042, IntSize::I32), + ], + vec![ + MessageValue::Integer(9042, IntSize::I32), + MessageValue::Strings("World".into()), + MessageValue::Integer(9042, IntSize::I32), + ], + ], + ); + + let expected = create_response_message( + &col_spec, + vec![ + vec![ + MessageValue::Integer(9043, IntSize::I32), + MessageValue::Strings("Hello".into()), + MessageValue::Integer(9043, IntSize::I32), + ], + vec![ + MessageValue::Integer(9043, IntSize::I32), + MessageValue::Strings("World".into()), + MessageValue::Integer(9043, IntSize::I32), + ], + ], + ); + + rewrite_port( + &mut original, + &["native_port".to_string(), "alias_port".to_string()], + 9043, + ); + + assert_eq!(original, expected); } } diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 9a661c406..bcd08730f 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -219,7 +219,7 @@ impl Transform for Protect { // encrypt the values included in any INSERT or UPDATE queries for message in message_wrapper.messages.iter_mut() { let mut data_changed = false; - if let Some(namespace) = message.namespace() { + if let Some(namespace) = message.get_table_names() { if namespace.len() == 2 { if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Query { query, .. }, @@ -257,7 +257,7 @@ impl Transform for Protect { .. })) = response.frame() { - if let Some(namespace) = request.namespace() { + if let Some(namespace) = request.get_table_names() { if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Query { query, .. }, .. diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index fb19ac6a9..0655850e1 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,7 +1,7 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame}; -use crate::message::{Message, MessageValue, Messages, QueryType}; +use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ build_chain_from_config, Transform, Transforms, TransformsConfig, Wrapper, @@ -11,7 +11,7 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use cassandra_protocol::frame::Version; use serde::Deserialize; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, PrimaryKey, RelationElement, RelationOperator, WhereClause}; @@ -62,7 +62,7 @@ impl SimpleRedisCache { &mut self, mut messages_cass_request: Messages ) -> ChainResponse { - // This function is a little hard to follow, so heres an overview. + // This function is a little hard to follow, so here's an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. // 1. messages_cass_request: // * the cassandra requests that the function receives. @@ -82,7 +82,8 @@ impl SimpleRedisCache { let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); for cass_request in &mut messages_cass_request { - if let Some(table_name) = cass_request.namespace().map(|x| x.join(".")) { + + if let Some(table_name) = cass_request.get_table_names().map(|x| x.join(".")) { match cass_request.frame() { Some(Frame::Cassandra(frame)) => { for query in frame.operation.queries()? { @@ -233,7 +234,7 @@ fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], tabl }).for_each( |(_name,relation_elements)| { for relation_element in relation_elements { for operand in &relation_element.value { - let x = build_zrangebylex_min_max_from_cql3(&relation_element.oper, &operand, &mut min, &mut max, ); + let x = build_zrangebylex_min_max_from_cql3(&relation_element.oper, operand, &mut min, &mut max, ); if x.is_err() { had_err = x.err() } @@ -266,7 +267,7 @@ fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], tabl Some(&y.unwrap().value) }) .fold(BytesMut::new(), |mut acc, v| { - v.iter().for_each(|operand| acc.extend(MessageValue::from( operand ).into_str_bytes())); + v.iter().for_each(|operand| acc.extend(operand.to_string().as_bytes())); acc }); Ok(vec![ @@ -276,6 +277,20 @@ fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], tabl RedisFrame::BulkString(max), ]) } + +fn extract_partition_key( partition_key_columns : &[String], value_map : &BTreeMap) -> Result{ + let pk = partition_key_columns + .iter() + .map(|k| + value_map.get(k.as_str()).unwrap() + ) + .fold(BytesMut::new(), |mut acc, v| { + acc.extend(v.to_string().as_bytes()); + acc + }); + Ok(pk) +} + fn build_redis_ast_from_cql3 ( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, @@ -290,36 +305,48 @@ fn build_redis_ast_from_cql3 ( } } CassandraStatement::Insert(insert) => { - let value_map : HashMap = insert.get_value_map(); - let pk = table_cache_schema - .partition_key - .iter() - .map(|k| value_map.get(k.as_str()).unwrap()) - .fold(BytesMut::new(), |mut acc, v| { - acc.extend(MessageValue::from(*v).into_str_bytes()); - acc - }); - let redis_frames: Vec = vec![ + // partition key from the value map + // values from the remaining parts of the value map. + let value_map : BTreeMap = insert.get_value_map(); + let pk = extract_partition_key( &table_cache_schema.partition_key, &value_map )?; + let mut redis_frames: Vec = vec![ RedisFrame::BulkString("ZADD".into()), RedisFrame::BulkString(pk.freeze()), ]; - - let mut map = HashMap::new(); - value_map.iter().for_each(|(key,value)| {map.insert( key.clone(), MessageValue::from( *value ));}); - Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames))) + add_values_to_redis_frames(table_cache_schema, value_map, &mut redis_frames)?; + Ok(RedisFrame::Array(redis_frames)) } CassandraStatement::Update(update) => { - let redis_frames = build_redis_frames_from_where_clause( &update.where_clause, table_cache_schema)?; - let mut map = HashMap::new(); - for x in &update.assignments { - // skip any columns with +/- modifiers. - if x.operator.is_none() { - map.insert(x.name.to_string(), MessageValue::from(&x.value)); + // only want the partition key built from `equals` statements in the where clause + // and values from the set clause + let where_tree = WhereClause::get_column_relation_element_map(&update.where_clause); + let mut value_map :BTreeMap = where_tree + .iter().filter_map( |(k,v)| { + for relation in v { + if relation.oper == RelationOperator::Equal && relation.value.len() == 1 { + return Some((k.clone(),&relation.value[0])); + } + } + None + }).collect(); + let mut has_err = false; + update.assignments.iter().for_each( |assignment| { + if assignment.operator.is_some() { + has_err = true; + } else { + value_map.insert( assignment.name.to_string(), &assignment.value ); } + }); + if has_err { + return Err(anyhow!("Set values include operations")); } - - Ok(RedisFrame::Array(add_values_to_redis_frames(table_cache_schema, map, redis_frames ))) - + let pk = extract_partition_key( &table_cache_schema.partition_key, &value_map )?; + let mut redis_frames: Vec = vec![ + RedisFrame::BulkString("ZADD".into()), + RedisFrame::BulkString(pk.freeze()), + ]; + add_values_to_redis_frames(table_cache_schema, value_map, &mut redis_frames)?; + Ok(RedisFrame::Array(redis_frames)) } statement => Err(anyhow!("Cant build query from statement: {}", statement)), } @@ -327,41 +354,48 @@ fn build_redis_ast_from_cql3 ( fn add_values_to_redis_frames( table_cache_schema: &TableCacheSchema, - query_values: HashMap, - mut redis_frames : Vec -) -> Vec { + query_values: BTreeMap, + redis_frames : &mut Vec +) -> Result<()> { - let clustering = table_cache_schema + let mut has_err = None; + let mut clustering = table_cache_schema .range_key .iter() - .map(|k| query_values.get(k.as_str()).unwrap()) - .fold(BytesMut::new(), |mut acc, message_value| { - acc.extend(&message_value.clone().into_str_bytes()); + .filter_map(|k| { + if let Some(x) = query_values.get(k.as_str()) { + Some(x) + } else { + has_err = Some(anyhow!( "Clustering column {} missing from statement", k )); + None + } + }) + .fold(BytesMut::new(), |mut acc, operand| { + acc.extend(operand.to_string().as_bytes()); acc }); + if let Some(e) = has_err { + return Err(e); + } + if !clustering.is_empty() { + clustering.put_u8(b':'); + } + redis_frames.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); query_values .iter() .filter_map(|(p, v)| { - if !(table_cache_schema.partition_key.contains(p) || - table_cache_schema.range_key.contains( p)) - { + if table_cache_schema.partition_key.contains(p) || + table_cache_schema.range_key.contains(p) { None - } else { - Some(v) - } + } else { Some(*v)} }) - .for_each( |message_value| { - redis_frames.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); - let mut value = clustering.clone(); - if !value.is_empty() { - value.put_u8(b':'); - } - value.extend(message_value.clone().into_str_bytes()); - redis_frames.push(RedisFrame::BulkString(value.freeze())); + .for_each( |operand| { + clustering.extend(operand.to_string().as_bytes()); }); + redis_frames.push(RedisFrame::BulkString(clustering.freeze())); - redis_frames + Ok(()) } #[async_trait] @@ -433,10 +467,10 @@ mod test { use std::collections::HashMap; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; - use tls_parser::nom::AsBytes; fn build_query(query_string: &str) -> CassandraStatement { let ast = CassandraAST::new( query_string ); + assert!( !ast.has_error() ); ast.statements[0].clone() } @@ -457,41 +491,6 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"[123:965")), RedisFrame::BulkString(Bytes::from_static(b"]123:965")), ]); - - if let RedisFrame::Array(v)= query { - assert_eq!(4,v.len()); - let mut iter = v.iter(); - if let RedisFrame::BulkString( b ) = iter.next().unwrap() { - assert_eq!( b"ZRANGEBYLEX", b.as_bytes()); - } - if let RedisFrame::BulkString( b ) = iter.next().unwrap() { - assert_eq!( b"1", b.as_bytes()); - } - - if let RedisFrame::BulkString( b ) = iter.next().unwrap() { - if b.starts_with( b"[123:") { - assert_eq!( b"[123:965", b.as_bytes()); - } else { - assert_eq!( b"[965:123", b.as_bytes()); - } - } else { - assert!(false); - } - - if let RedisFrame::BulkString( b ) = iter.next().unwrap() { - if b.starts_with( b"]123:") { - assert_eq!( b"]123:965", b.as_bytes()); - } else { - assert_eq!( b"]965:123", b.as_bytes()); - } - } else { - assert!(false); - } - - } else { - assert!(false) - } - assert_eq!(expected, query); } @@ -530,7 +529,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"yo:123")), + RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), ]); assert_eq!(expected, query); @@ -551,7 +550,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"yo:123")), + RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), ]); assert_eq!(expected, query); @@ -574,7 +573,7 @@ mod test { // Semantically databases treat the order of AND clauses differently, Cassandra however requires clustering key predicates be in order // So here we will just expect the order is correct in the query. TODO: we may need to revisit this as support for other databases is added - assert_ne!(query_one, query_two); + assert_eq!(query_one, query_two); } #[test] From d1095e1d5d64b34029727ffcb79b4eff52d237c7 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 4 Apr 2022 13:05:47 +0100 Subject: [PATCH 07/60] updated CQL to contain only one statement --- Cargo.lock | 2 +- shotover-proxy/src/frame/cassandra.rs | 83 +++++++------ shotover-proxy/src/message/mod.rs | 14 +-- .../src/transforms/cassandra/peers_rewrite.rs | 8 +- shotover-proxy/src/transforms/protect/mod.rs | 112 +++++++++--------- shotover-proxy/src/transforms/redis/cache.rs | 102 +++++++++------- .../cassandra_int_tests/basic_driver_tests.rs | 6 +- 7 files changed, 174 insertions(+), 153 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0e5d40bcb..3218b0ac0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#6d7f96a8d7d8fee3c80d9a14f3a2060a33d006b3" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#73691a61b89870d3a01a064ee0274502f35f0067" dependencies = [ "bigdecimal", "bytes", diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 32763e2a9..a227c4266 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; @@ -251,7 +250,7 @@ impl CassandraFrame { PubSubMessage, */ match &self.operation { - CassandraOperation::Query { query: cql, .. } => match cql.statement.get(0).unwrap() { + CassandraOperation::Query { query: cql, .. } => match cql.statement { CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, CassandraStatement::AlterRole(_) => QueryType::SchemaChange, @@ -269,7 +268,7 @@ impl CassandraFrame { CassandraStatement::CreateTrigger(_) => QueryType::SchemaChange, CassandraStatement::CreateType(_) => QueryType::SchemaChange, CassandraStatement::CreateUser(_) => QueryType::SchemaChange, - CassandraStatement::DeleteStatement(_) => QueryType::Write, + CassandraStatement::Delete(_) => QueryType::Write, CassandraStatement::DropAggregate(_) => QueryType::SchemaChange, CassandraStatement::DropFunction(_) => QueryType::SchemaChange, CassandraStatement::DropIndex(_) => QueryType::SchemaChange, @@ -295,35 +294,30 @@ impl CassandraFrame { } } + + /// returns a mapping of table names to (index,statement) pairs, where index is the index in the CQL /// of the statement. - pub fn get_table_name_statement_map(&self) -> HashMap> { - let mut result: HashMap> = HashMap::new(); - if let CassandraOperation::Query { query: cql, .. } = &self.operation { - - cql.statement.iter().enumerate().for_each( |(idx,statement)| { - let name = match statement { - CassandraStatement::AlterTable(t) => {Some(&t.name)} - CassandraStatement::CreateIndex(i) => {Some(&i.table)} - CassandraStatement::CreateMaterializedView(m) => {Some(&m.table)} - CassandraStatement::CreateTable(t) => {Some(&t.name)} - CassandraStatement::DropTable(t) => {Some(&t.name)} - CassandraStatement::DropTrigger(t) => {Some(&t.table)} - CassandraStatement::Insert(i) => {Some(&i.table_name)} - CassandraStatement::Select(s) => {Some(&s.table_name)} - CassandraStatement::Truncate(t) => {Some(t)} - CassandraStatement::Update(u) => {Some(&u.table_name)} - _ => None - }; - if let Some(k) = name { - if let Some(v) = result.get_mut(k) { - v.push( (idx,statement)); - } else { - result.insert( k.to_string(), vec![(idx,statement)]); + pub fn get_table_names(&self) -> Vec { + let mut result = vec!(); + match &self.operation { + CassandraOperation::Query { query: cql, .. } => { + if let Some(name) = cql.get_table_name() { + result.push( name.into() ); + } + } + CassandraOperation::Batch( batch ) => { + for q in &batch.queries { + + if let BatchStatementType::Statement(cql) = &q.ty { + if let Some(name) = cql.get_table_name() { + result.push( name.into() ); + } } } - }); - }; + }, + _ => {} + } result } @@ -365,10 +359,13 @@ pub enum CassandraOperation { impl CassandraOperation { /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch /// An Err is returned if the operation cannot contain queries or the queries failed to parse. - pub fn queries(&mut self) -> Result> { + /// + /// TODO: This will return a custom iterator type when BATCH support is added + pub fn queries(&mut self) -> Result> { match self { - CassandraOperation::Query { query: cql, .. } => Ok(cql.statement.iter_mut()), + CassandraOperation::Query { query: cql, .. } => Ok(std::iter::once(&mut cql.statement )), // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol + _ => Err(anyhow!("This operation cannot contain queries")), } } @@ -499,20 +496,40 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQL { - pub statement: Vec, + pub statement: CassandraStatement, pub has_error: bool, } impl CQL { pub fn to_query_string(&self) -> String { - self.statement.iter().join(";") + self.statement.to_string() } + /// the CassandraAST handles multiple queries in a string separated by semi-colons: `;` however + /// CQL only stores one query so this method only returns the first one if there are multiples. pub fn parse_from_string(cql_query_str: &str) -> Self { let ast = CassandraAST::new(cql_query_str); CQL { has_error: ast.has_error(), - statement: ast.statements, + statement: ast.statements.first().unwrap().clone(), + } + } + + /// returns the table name specified in the command if one is present. + pub fn get_table_name(&self) -> Option<&String> { + match &self.statement { + CassandraStatement::AlterTable(t) => { Some(&t.name) } + CassandraStatement::CreateIndex(i) => { Some(&i.table) } + CassandraStatement::CreateMaterializedView(m) => { Some(&m.table) } + CassandraStatement::CreateTable(t) => { Some(&t.name) } + CassandraStatement::Delete(d) => { Some(&d.table_name) } + CassandraStatement::DropTable(t) => { Some(&t.name) } + CassandraStatement::DropTrigger(t) => { Some(&t.table) } + CassandraStatement::Insert(i) => { Some(&i.table_name) } + CassandraStatement::Select(s) => { Some(&s.table_name) } + CassandraStatement::Truncate(t) => { Some(t) } + CassandraStatement::Update(u) => { Some(&u.table_name) } + _ => None } } } diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 022d84c88..7c6a000ba 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -174,16 +174,16 @@ impl Message { } } - /// Returns the table names found in the message. /// None if the statements do not contain table names. - pub fn get_table_names(&mut self) -> Option> { - match self.frame()? { - Frame::Cassandra(cassandra) => { - Some(cassandra.get_table_name_statement_map().iter().map( |(k,_v)| k.clone()).collect()) + pub fn get_table_names(&mut self) -> Vec { + match self.frame() { + Some(Frame::Cassandra(cassandra)) => { + cassandra.get_table_names() } - Frame::Redis(_) => unimplemented!(), - Frame::None => Some(vec![]), + Some(Frame::Redis(_)) => unimplemented!(), + Some(Frame::None) => vec![], + _ => unreachable!() } } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 68a7f9e41..26c0ea5e6 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -71,11 +71,8 @@ impl Transform for CassandraPeersRewrite { fn extract_native_port_column(message: &mut Message) -> Vec { let mut result: Vec = vec![]; if let Some(Frame::Cassandra(cassandra)) = message.frame() { - let map = cassandra.get_table_name_statement_map(); - let statements = map.get("system.peers_v2"); - if let Some(v) = statements { - for (_stmt_idx, x) in v { - if let CassandraStatement::Select(select) = x { + if let CassandraOperation::Query { query, .. } = &cassandra.operation { + if let CassandraStatement::Select(select) = &query.statement { select .columns .iter() @@ -88,7 +85,6 @@ fn extract_native_port_column(message: &mut Message) -> Vec { SelectElement::Star => result.push("native_port".to_string()), _ => {} }); - } } } } diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index bcd08730f..46da3753b 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -16,6 +16,7 @@ use cql3_parser::common::Operand; use cql3_parser::insert::InsertValues; use cql3_parser::select::SelectElement; use sodiumoxide::hex; +use sqlparser::test_utils::table; use tracing::warn; @@ -219,25 +220,23 @@ impl Transform for Protect { // encrypt the values included in any INSERT or UPDATE queries for message in message_wrapper.messages.iter_mut() { let mut data_changed = false; - if let Some(namespace) = message.get_table_names() { - if namespace.len() == 2 { - if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) = message.frame() + + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) = message.frame() + { + if let Some(table_name) = query.get_table_name() { + if let Some((_, tables)) = + self.keyspace_table_columns.get_key_value(table_name) { - if let Some((_, tables)) = - self.keyspace_table_columns.get_key_value(&namespace[0]) - { - if let Some((_, columns)) = tables.get_key_value(&namespace[1]) { - for stmt in &mut query.statement { - data_changed = encrypt_columns(stmt, columns, &self.key_source, &self.key_id ).await?; - } - } + if let Some((_, columns)) = tables.get_key_value(table_name ) { + data_changed = encrypt_columns(&mut query.statement, columns, &self.key_source, &self.key_id).await?; } } } } + if data_changed { message.invalidate_cache(); } @@ -257,53 +256,48 @@ impl Transform for Protect { .. })) = response.frame() { - if let Some(namespace) = request.get_table_names() { - if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) = request.frame() - { - - if namespace.len() == 2 { - if let Some((_keyspace, tables)) = - self.keyspace_table_columns.get_key_value(&namespace[0]) + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) = request.frame() + { + if let Some(table_name) = query.get_table_name() { + if let Some((_keyspace, tables)) = + self.keyspace_table_columns.get_key_value(table_name) + { + if let Some((_table, protect_columns)) = + tables.get_key_value(table_name) { - if let Some((_table, protect_columns)) = - tables.get_key_value(&namespace[1]) - { - for cassandra_statement in &query.statement { - if let CassandraStatement::Select(select) = cassandra_statement { - let positions : Vec = select.columns.iter().enumerate() - .filter_map( | (i,col)| { - if let SelectElement::Column(named) = col { - if protect_columns.contains(&named.name) - { - Some(i) - } else { - None - } - } else { - None - } - }).collect(); - for row in &mut *rows { - for index in &positions { - if let Some(v) = row.get_mut(*index) { - if let MessageValue::Bytes(_) = v { - let protected = - Protected::from_encrypted_bytes_value(v) - .await?; - let new_value: MessageValue = protected - .unprotect(&self.key_source, &self.key_id) - .await?; - *v = new_value; - invalidate_cache = true; - } else { - warn!("Tried decrypting non-blob column") - } - } + if let CassandraStatement::Select(select) = &query.statement { + let positions: Vec = select.columns.iter().enumerate() + .filter_map(|(i, col)| { + if let SelectElement::Column(named) = col { + if protect_columns.contains(&named.name) + { + Some(i) + } else { + None } - } + } else { + None + } + }).collect(); + for row in &mut *rows { + for index in &positions { + if let Some(v) = row.get_mut(*index) { + if let MessageValue::Bytes(_) = v { + let protected = + Protected::from_encrypted_bytes_value(v) + .await?; + let new_value: MessageValue = protected + .unprotect(&self.key_source, &self.key_id) + .await?; + *v = new_value; + invalidate_cache = true; + } else { + warn!("Tried decrypting non-blob column") + } + } } } } diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 0655850e1..f2ae283d1 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -81,59 +81,73 @@ impl SimpleRedisCache { // * These are the cassandra responses that we return from the function. let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); - for cass_request in &mut messages_cass_request { + // process only as long as all the cass requests can be answered with from the cache. - if let Some(table_name) = cass_request.get_table_names().map(|x| x.join(".")) { - match cass_request.frame() { - Some(Frame::Cassandra(frame)) => { - for query in frame.operation.queries()? { - let table_cache_schema = self + for cass_request in &mut messages_cass_request { + match cass_request.frame() { + Some(Frame::Cassandra(frame)) => { + // get the statements that have table names + if let CassandraOperation::Query { query, params } = &frame.operation { + if let Some(table_name) = query.get_table_name() { + + // process the table if it is listed in the caching schema + if let Some(table_cache_schema) = self .caching_schema - .get(&table_name) - .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; - - messages_redis_request.push(Message::from_frame(Frame::Redis( - build_redis_ast_from_cql3(query, table_cache_schema)?, - ))); + .get(table_name) { + match query.statement { + CassandraStatement::Insert(_) | + CassandraStatement::Update(_) | + CassandraStatement::Delete(_) => { + let result = build_redis_ast_from_cql3(&query.statement, table_cache_schema); + if result.is_ok() { + messages_redis_request.push(Message::from_frame(Frame::Redis(result.unwrap()))); + } + }, + _ => {} + } + } } } - message => bail!("cannot fetch {message:?} from cache"), + }, + _ => { + // statements that we can not handle } - } else { - bail!("Failed to get message namespace"); } } - - let messages_redis_response = self - .cache_chain - .process_request( - Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), - "clientdetailstodo".to_string(), - ) - .await?; - - // Replace cass_request messages with cassandra responses in place. - // We reuse the vec like this to save allocations. - let mut messages_redis_response_iter = messages_redis_response.into_iter(); - for cass_request in &mut messages_cass_request { - let mut redis_responses = vec![]; - if let Some(Frame::Cassandra(frame)) = cass_request.frame() { - if let Ok(queries) = frame.operation.queries() { - for _query in queries { - redis_responses.push(messages_redis_response_iter.next()); + // if we can handle all the query from the cache then do so + if ! &messages_redis_request.is_empty() { + let messages_redis_response = self + .cache_chain + .process_request( + Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), + "clientdetailstodo".to_string(), + ) + .await?; + + // Replace cass_request messages with cassandra responses in place. + // We reuse the vec like this to save allocations. + let mut messages_redis_response_iter = messages_redis_response.into_iter(); + for cass_request in &mut messages_cass_request { + let mut redis_responses = vec![]; + if let Some(Frame::Cassandra(frame)) = cass_request.frame() { + if let Ok(queries) = frame.operation.queries() { + for _query in queries { + redis_responses.push(messages_redis_response_iter.next()); + } + } } + + // TODO: Translate the redis_responses into a cassandra result + *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { + version: Version::V4, + operation: CassandraOperation::Result(CassandraResult::Void), + stream_id: cass_request.stream_id().unwrap(), + tracing_id: None, + warnings: vec![], + })); } } - // TODO: Translate the redis_responses into a cassandra result - *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - operation: CassandraOperation::Result(CassandraResult::Void), - stream_id: cass_request.stream_id().unwrap(), - tracing_id: None, - warnings: vec![], - })); - } Ok(messages_cass_request) } } @@ -301,7 +315,7 @@ fn build_redis_ast_from_cql3 ( if select.where_clause.is_some() { Ok(RedisFrame::Array( build_redis_frames_from_where_clause( select.where_clause.as_ref().unwrap(),table_cache_schema)?)) } else { - Err(anyhow!("Cant build query from statement: {}", statement)) + Err(anyhow!("Can't build query from statement: {}", statement)) } } CassandraStatement::Insert(insert) => { @@ -348,7 +362,7 @@ fn build_redis_ast_from_cql3 ( add_values_to_redis_frames(table_cache_schema, value_map, &mut redis_frames)?; Ok(RedisFrame::Array(redis_frames)) } - statement => Err(anyhow!("Cant build query from statement: {}", statement)), + _ => unreachable!(), } } diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index 93f2da17f..7192b4812 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1200,9 +1200,9 @@ mod cache { ["1", "2", "3"].into_iter().map(|x| x.to_string()).collect(); assert_eq!(result, expected); - assert_sorted_set_equals(redis_connection, "1", &["1:11", "1:foo"]); - assert_sorted_set_equals(redis_connection, "2", &["2:12", "2:bar"]); - assert_sorted_set_equals(redis_connection, "3", &["3:13", "3:baz"]); + assert_sorted_set_equals(redis_connection, "1", &["1:11", "1:'foo'"]); + assert_sorted_set_equals(redis_connection, "2", &["2:12", "2:'bar'"]); + assert_sorted_set_equals(redis_connection, "3", &["3:13", "3:'baz'"]); } fn assert_sorted_set_equals( From e9a72d364228f1fa305109359e46a741275e5a3f Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 6 Apr 2022 13:00:21 +0100 Subject: [PATCH 08/60] compiles --- Cargo.lock | 46 +- shotover-proxy/src/codec/cassandra.rs | 98 ++-- shotover-proxy/src/frame/cassandra.rs | 375 +++++++++++-- shotover-proxy/src/message/mod.rs | 103 ++-- .../src/transforms/cassandra/peers_rewrite.rs | 48 +- shotover-proxy/src/transforms/protect/mod.rs | 105 ++-- .../src/transforms/query_counter.rs | 1 - shotover-proxy/src/transforms/redis/cache.rs | 499 ++++++++---------- 8 files changed, 782 insertions(+), 493 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3218b0ac0..934c97775 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -356,9 +356,9 @@ dependencies = [ [[package]] name = "clap" -version = "3.1.7" +version = "3.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c67e7973e74896f4bba06ca2dcfd28d54f9cb8c035e940a32b88ed48f5f5ecf2" +checksum = "71c47df61d9e16dc010b55dba1952a57d8c215dbb533fd13cdd13369aac73b1c" dependencies = [ "atty", "bitflags", @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#73691a61b89870d3a01a064ee0274502f35f0067" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#4b04fdc52d1aaea01dfc366e6d4382ad93c65f8e" dependencies = [ "bigdecimal", "bytes", @@ -581,9 +581,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.13.1" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0d720b8683f8dd83c65155f0530560cba68cd2bf395f6513a483caee57ff7f4" +checksum = "dbcc37e3091b4dfd0af76cb0087b9c89b8e03072abc28ae2efc8fdd733bfc5f5" dependencies = [ "darling_core", "darling_macro", @@ -591,9 +591,9 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.13.1" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a340f241d2ceed1deb47ae36c4144b2707ec7dd0b649f894cb39bb595986324" +checksum = "9569a966dba8cd57879b8efd2bf82b5c56bb466e19767a69c560bddee1a27f5c" dependencies = [ "fnv", "ident_case", @@ -605,9 +605,9 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.13.1" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72c41b3b7352feb3211a0d743dc5700a4e3b60f51bd2b368892d1e0f9a95f44b" +checksum = "efae147148c6380157050146a2040b65dbe91bef6e97aaaa39ef0d469d2eb4af" dependencies = [ "darling_core", "quote", @@ -691,9 +691,9 @@ checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" [[package]] name = "encoding_rs" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dc8abb250ffdda33912550faa54c88ec8b998dec0b2c55ab224921ce11df" +checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" dependencies = [ "cfg-if", ] @@ -1914,9 +1914,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029" +checksum = "ec757218438d5fda206afc041538b2f6d889286160d649a86a24d37e1235afd1" dependencies = [ "unicode-xid", ] @@ -2479,7 +2479,7 @@ dependencies = [ "cached", "cassandra-cpp", "cassandra-protocol", - "clap 3.1.7", + "clap 3.1.8", "cql3_parser", "crc16", "criterion", @@ -2563,9 +2563,9 @@ checksum = "76a77a8fd93886010f05e7ea0720e569d6d16c65329dbe3ec033bbbccccb017b" [[package]] name = "slab" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" +checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" [[package]] name = "slog" @@ -2659,9 +2659,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.90" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704df27628939572cd88d33f171cd6f896f4eaca85252c6e0a72d8d8287ee86f" +checksum = "b683b2b825c8eef438b77c36a06dc262294da3d5a5813fac20da149241dcd44d" dependencies = [ "proc-macro2", "quote", @@ -2975,9 +2975,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c" +checksum = "90442985ee2f57c9e1b548ee72ae842f4a9a20e3f417cc38dbc5dc684d9bb4ee" dependencies = [ "lazy_static", "valuable", @@ -2997,9 +2997,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e0ab7bdc962035a87fba73f3acca9b8a8d0034c2e6f60b84aeaaddddc155dce" +checksum = "b9df98b037d039d03400d9dd06b0f8ce05486b5f25e9a2d7d36196e142ebbc52" dependencies = [ "ansi_term", "lazy_static", @@ -3026,7 +3026,7 @@ dependencies = [ [[package]] name = "tree-sitter-cql" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/tree-sitter-cql?branch=main#c01b7c8682b68cb8fcec00adc667744fdc8db2e2" +source = "git+https://github.com/Claude-at-Instaclustr/tree-sitter-cql?branch=main#ead1375c83892111d31540e6d8d0dc563f187939" dependencies = [ "cc", "regex", diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index 434e0c602..a6c143c31 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -366,57 +366,57 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::parse_from_string("Select * from system.local where key = 'local'"), + query: CQL::parse_from_string("Select * from system.local where key = 'local'"), /* - CQL::Parsed(vec![Statement::Query(Box::new(Query { - with: None, - body: SetExpr::Select(Box::new(Select { - distinct: false, - top: None, - projection: vec![SelectItem::Wildcard], - from: vec![TableWithJoins { - relation: TableFactor::Table { - name: ObjectName(vec![ - Ident { - value: "system".into(), - quote_style: None, - }, - Ident { - value: "local".into(), - quote_style: None, - }, - ]), - alias: None, - args: vec![], - with_hints: vec![], - }, - joins: vec![], - }], - lateral_views: vec![], - selection: Some(BinaryOp { - left: Box::new(Expr::Identifier(Ident { - value: "key".into(), - quote_style: None, - })), - op: BinaryOperator::Eq, - right: Box::new(Expr::Value(SQLValue::SingleQuotedString( - "local".into(), - ))), - }), - group_by: vec![], - cluster_by: vec![], - distribute_by: vec![], - sort_by: vec![], - having: None, - })), - order_by: vec![], - limit: None, - offset: None, - fetch: None, - lock: None, - params: Default::default() - }))]),*/ + CQL::Parsed(vec![Statement::Query(Box::new(Query { + with: None, + body: SetExpr::Select(Box::new(Select { + distinct: false, + top: None, + projection: vec![SelectItem::Wildcard], + from: vec![TableWithJoins { + relation: TableFactor::Table { + name: ObjectName(vec![ + Ident { + value: "system".into(), + quote_style: None, + }, + Ident { + value: "local".into(), + quote_style: None, + }, + ]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![], + }], + lateral_views: vec![], + selection: Some(BinaryOp { + left: Box::new(Expr::Identifier(Ident { + value: "key".into(), + quote_style: None, + })), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(SQLValue::SingleQuotedString( + "local".into(), + ))), + }), + group_by: vec![], + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + having: None, + })), + order_by: vec![], + limit: None, + offset: None, + fetch: None, + lock: None, + params: Default::default() + }))]),*/ params: QueryParams::default(), }, }))]; diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index a227c4266..1b81da680 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -9,7 +9,7 @@ use cassandra_protocol::frame::frame_query::BodyReqQuery; use cassandra_protocol::frame::frame_request::RequestBody; use cassandra_protocol::frame::frame_response::ResponseBody; use cassandra_protocol::frame::frame_result::{ - BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ResResultBody, + BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ColSpec, ResResultBody, RowsMetadata, RowsMetadataFlags, }; use cassandra_protocol::frame::{ @@ -18,17 +18,18 @@ use cassandra_protocol::frame::{ use cassandra_protocol::query::{QueryParams, QueryValues}; use cassandra_protocol::types::blob::Blob; use cassandra_protocol::types::cassandra_type::CassandraType; +use cassandra_protocol::types::value::Value; use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::Operand; -use itertools::Itertools; +use cql3_parser::common::{Operand, RelationElement}; +use cql3_parser::insert::InsertValues; +use cql3_parser::update::AssignmentOperator; use nonzero_ext::nonzero; use sodiumoxide::hex; use std::convert::TryInto; use std::net::IpAddr; use std::num::NonZeroU32; -use std::slice::IterMut; use std::str::FromStr; use uuid::Uuid; @@ -294,28 +295,24 @@ impl CassandraFrame { } } - - - /// returns a mapping of table names to (index,statement) pairs, where index is the index in the CQL - /// of the statement. + /// returns a list of table names from the CassandraOperation pub fn get_table_names(&self) -> Vec { - let mut result = vec!(); + let mut result = vec![]; match &self.operation { CassandraOperation::Query { query: cql, .. } => { - if let Some(name) = cql.get_table_name() { - result.push( name.into() ); + if let Some(name) = CQL::get_table_name(&cql.statement) { + result.push(name.into()); } } - CassandraOperation::Batch( batch ) => { - for q in &batch.queries { - - if let BatchStatementType::Statement(cql) = &q.ty { - if let Some(name) = cql.get_table_name() { - result.push( name.into() ); - } + CassandraOperation::Batch(batch) => { + for q in &batch.queries { + if let BatchStatementType::Statement(cql) = &q.ty { + if let Some(name) = CQL::get_table_name(&cql.statement) { + result.push(name.into()); } } - }, + } + } _ => {} } result @@ -363,9 +360,8 @@ impl CassandraOperation { /// TODO: This will return a custom iterator type when BATCH support is added pub fn queries(&mut self) -> Result> { match self { - CassandraOperation::Query { query: cql, .. } => Ok(std::iter::once(&mut cql.statement )), + CassandraOperation::Query { query: cql, .. } => Ok(std::iter::once(&mut cql.statement)), // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol - _ => Err(anyhow!("This operation cannot contain queries")), } } @@ -501,6 +497,306 @@ pub struct CQL { } impl CQL { + fn from_value_and_col_spec(value: &Value, col_spec: &ColSpec) -> Operand { + match value { + Value::Some(vec) => { + let cbytes = CBytes::new(vec.clone()); + let message_value = + MessageValue::build_value_from_cstar_col_type(col_spec, &cbytes); + let pmsg_value = &message_value; + pmsg_value.into() + } + Value::Null => Operand::Null, + Value::NotSet => Operand::Null, + } + } + fn set_param_value_by_name( + name: &str, + query_params: &QueryParams, + param_types: &[ColSpec], + ) -> Operand { + if let Some(values) = &query_params.values { + if let QueryValues::NamedValues(value_map) = values { + if let Some(value) = value_map.get(name) { + if let Some(idx) = value_map + .iter() + .enumerate() + .filter_map( + |(idx, (key, _value))| { + if key.eq(name) { + Some(idx) + } else { + None + } + }, + ) + .next() + { + return CQL::from_value_and_col_spec(value, ¶m_types[idx]); + } + } + } + } + Operand::Param(format!(":{}", name)) + } + + fn set_param_value_by_position( + param_idx: &mut usize, + query_params: &QueryParams, + param_types: &[ColSpec], + ) -> Operand { + if let Some(QueryValues::SimpleValues(values)) = &query_params.values { + if let Some(value) = values.get(*param_idx) { + *param_idx += 1; + CQL::from_value_and_col_spec(value, ¶m_types[*param_idx]) + } else { + *param_idx += 1; + Operand::Param("?".into()) + } + } else { + *param_idx += 1; + Operand::Param("?".into()) + } + } + + fn set_operand_if_param( + operand: &Operand, + mut param_idx: &mut usize, + query_params: &QueryParams, + param_types: &[ColSpec], + ) -> Operand { + match operand { + Operand::Tuple(vec) => { + let mut vec2 = Vec::with_capacity(vec.len()); + vec.iter().for_each(|o| { + vec2.push(CQL::set_operand_if_param( + o, + &mut param_idx, + query_params, + param_types, + )) + }); + + Operand::Tuple(vec2) + } + Operand::Param(param_name) => { + if param_name.starts_with("?") { + CQL::set_param_value_by_position(param_idx, query_params, param_types) + } else { + let name = param_name.split_at(0).1; + CQL::set_param_value_by_name(name, query_params, param_types) + } + } + Operand::Collection(vec) => { + let mut vec2 = Vec::with_capacity(vec.len()); + vec.iter().for_each(|o| { + vec2.push(CQL::set_operand_if_param( + o, + &mut param_idx, + query_params, + param_types, + )) + }); + + Operand::Collection(vec2) + } + _ => operand.clone(), + } + } + + fn set_relation_elements_values( + param_idx: &mut usize, + query_params: &QueryParams, + param_types: &[ColSpec], + where_clause: &mut [RelationElement], + ) { + for relation_idx in 0..where_clause.len() { + where_clause[relation_idx].value = CQL::set_operand_if_param( + &where_clause[relation_idx].value, + param_idx, + query_params, + param_types, + ); + } + } + + /// replaces the Operand::Param objects with Operand::Const objects where the parameters are defined in the + /// QueryParameters. + /// This method makes a copy of the CassandraStatement + pub fn set_param_values( + &self, + params: &QueryParams, + param_types: &[ColSpec], + ) -> CassandraStatement { + let mut param_idx: usize = 0; + let mut statement = self.statement.clone(); + match &mut statement { + CassandraStatement::Delete(delete) => { + CQL::set_relation_elements_values( + &mut param_idx, + params, + param_types, + &mut delete.where_clause, + ); + CQL::set_relation_elements_values( + &mut param_idx, + params, + param_types, + &mut delete.if_clause, + ); + } + CassandraStatement::Insert(insert) => { + if let InsertValues::Values(operands) = &mut insert.values { + for operand_idx in 0..operands.len() { + operands[operand_idx] = CQL::set_operand_if_param( + &mut operands[operand_idx], + &mut param_idx, + params, + param_types, + ) + } + } + } + CassandraStatement::Select(select) => { + CQL::set_relation_elements_values( + &mut param_idx, + params, + param_types, + &mut select.where_clause, + ); + } + CassandraStatement::Update(update) => { + for assignment_idx in 0..update.assignments.len() { + let mut assignment_element = &mut update.assignments[assignment_idx]; + assignment_element.value = CQL::set_operand_if_param( + &assignment_element.value, + &mut param_idx, + params, + param_types, + ); + if let Some(assignment_operator) = &assignment_element.operator { + match assignment_operator { + AssignmentOperator::Plus(operand) => { + assignment_element.operator = Option::from( + AssignmentOperator::Plus(CQL::set_operand_if_param( + &operand, + &mut param_idx, + params, + param_types, + )), + ); + } + AssignmentOperator::Minus(operand) => { + assignment_element.operator = Option::from( + AssignmentOperator::Minus(CQL::set_operand_if_param( + &operand, + &mut param_idx, + params, + param_types, + )), + ); + } + } + } + } + CQL::set_relation_elements_values( + &mut param_idx, + params, + param_types, + &mut update.where_clause, + ); + CQL::set_relation_elements_values( + &mut param_idx, + params, + param_types, + &mut update.if_clause, + ); + } + _ => {} + } + statement + } + + fn has_params_in_operand(operand: &Operand) -> bool { + match operand { + Operand::Tuple(vec) | Operand::Collection(vec) => { + for oper in vec { + if CQL::has_params_in_operand(oper) { + return true; + } + } + false + } + Operand::Param(_) => true, + _ => false, + } + } + + fn has_params_in_relation_elements(where_clause: &[RelationElement]) -> bool { + for relation_idx in where_clause { + if CQL::has_params_in_operand(&relation_idx.value) { + return true; + } + } + false + } + + /// Returns true if there are any parameters in the query + pub fn has_params(&self) -> bool { + let mut statement = self.statement.clone(); + match &mut statement { + CassandraStatement::Delete(delete) => { + if CQL::has_params_in_relation_elements(&delete.where_clause) { + return true; + } + if CQL::has_params_in_relation_elements(&delete.if_clause) { + return true; + } + } + CassandraStatement::Insert(insert) => { + if let InsertValues::Values(operands) = &mut insert.values { + for operand_idx in 0..operands.len() { + if let Operand::Param(_) = &operands[operand_idx] { + return true; + } + } + } + } + CassandraStatement::Select(select) => { + return CQL::has_params_in_relation_elements(&select.where_clause); + } + CassandraStatement::Update(update) => { + for assignment_element in &update.assignments { + if let Operand::Param(_) = &assignment_element.value { + return true; + } + if let Some(assignment_operator) = &assignment_element.operator { + match assignment_operator { + AssignmentOperator::Plus(operand) => { + if let Operand::Param(_) = operand { + return true; + } + } + AssignmentOperator::Minus(operand) => { + if let Operand::Param(_) = operand { + return true; + } + } + } + } + } + if CQL::has_params_in_relation_elements(&update.where_clause) { + return true; + } + if CQL::has_params_in_relation_elements(&update.if_clause) { + return true; + } + } + _ => {} + } + false + } + pub fn to_query_string(&self) -> String { self.statement.to_string() } @@ -516,20 +812,20 @@ impl CQL { } /// returns the table name specified in the command if one is present. - pub fn get_table_name(&self) -> Option<&String> { - match &self.statement { - CassandraStatement::AlterTable(t) => { Some(&t.name) } - CassandraStatement::CreateIndex(i) => { Some(&i.table) } - CassandraStatement::CreateMaterializedView(m) => { Some(&m.table) } - CassandraStatement::CreateTable(t) => { Some(&t.name) } - CassandraStatement::Delete(d) => { Some(&d.table_name) } - CassandraStatement::DropTable(t) => { Some(&t.name) } - CassandraStatement::DropTrigger(t) => { Some(&t.table) } - CassandraStatement::Insert(i) => { Some(&i.table_name) } - CassandraStatement::Select(s) => { Some(&s.table_name) } - CassandraStatement::Truncate(t) => { Some(t) } - CassandraStatement::Update(u) => { Some(&u.table_name) } - _ => None + pub fn get_table_name(statement: &CassandraStatement) -> Option<&String> { + match statement { + CassandraStatement::AlterTable(t) => Some(&t.name), + CassandraStatement::CreateIndex(i) => Some(&i.table), + CassandraStatement::CreateMaterializedView(m) => Some(&m.table), + CassandraStatement::CreateTable(t) => Some(&t.name), + CassandraStatement::Delete(d) => Some(&d.table_name), + CassandraStatement::DropTable(t) => Some(&t.name), + CassandraStatement::DropTrigger(t) => Some(&t.table), + CassandraStatement::Insert(i) => Some(&i.table_name), + CassandraStatement::Select(s) => Some(&s.table_name), + CassandraStatement::Truncate(t) => Some(t), + CassandraStatement::Update(u) => Some(&u.table_name), + _ => None, } } } @@ -544,7 +840,7 @@ impl ToCassandraType for Operand { // check for string types if value.starts_with('\'') || value.starts_with("$$") { Some(CassandraType::Varchar(value.to_string())) - } else if value.starts_with("0X") || value.starts_with("X'") { + } else if value.starts_with("0X") || value.starts_with("0x") { let mut chars = value.chars(); chars.next(); chars.next(); @@ -598,6 +894,13 @@ impl ToCassandraType for Operand { Operand::Column(value) => Some(CassandraType::Ascii(value.to_string())), Operand::Func(value) => Some(CassandraType::Ascii(value.to_string())), Operand::Null => Some(CassandraType::Null), + Operand::Param(_) => None, + Operand::Collection(values) => Some(CassandraType::List( + values + .iter() + .filter_map(|value| value.as_cassandra_type()) + .collect(), + )), } } } diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 7c6a000ba..99e80da16 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -19,6 +19,7 @@ use cassandra_protocol::{ CBytes, }, }; +use cql3_parser::common::{DataTypeName, Operand}; use itertools::Itertools; use nonzero_ext::nonzero; use num::BigInt; @@ -28,7 +29,6 @@ use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; use std::num::NonZeroU32; use uuid::Uuid; -use cql3_parser::common::{DataTypeName, Operand}; enum Metadata { Cassandra(CassandraMetadata), @@ -36,7 +36,6 @@ enum Metadata { None, } - pub type Messages = Vec; /// The Message type is designed to effeciently abstract over the message being in various states of processing. @@ -178,12 +177,10 @@ impl Message { /// None if the statements do not contain table names. pub fn get_table_names(&mut self) -> Vec { match self.frame() { - Some(Frame::Cassandra(cassandra)) => { - cassandra.get_table_names() - } + Some(Frame::Cassandra(cassandra)) => cassandra.get_table_names(), Some(Frame::Redis(_)) => unimplemented!(), Some(Frame::None) => vec![], - _ => unreachable!() + _ => unreachable!(), } } @@ -461,62 +458,75 @@ impl From<&MessageValue> for Operand { fn from(v: &MessageValue) -> Self { match v { MessageValue::NULL => Operand::Null, - MessageValue::Bytes(b) => Operand::from(b) , - MessageValue::Ascii( s ) | - MessageValue::Varchar(s) | - MessageValue::Strings(s) => Operand::from(s.as_str()), + MessageValue::Bytes(b) => Operand::from(b), + MessageValue::Ascii(s) | MessageValue::Varchar(s) | MessageValue::Strings(s) => { + Operand::from(s.as_str()) + } MessageValue::Integer(i, _) => Operand::from(i), MessageValue::Float(f) => Operand::from(&f.0), MessageValue::Boolean(b) => Operand::from(b), MessageValue::Double(d) => Operand::from(&d.0), - MessageValue::Inet(i) => Operand::from( i ), - MessageValue::Varint( i ) => Operand::from(i), - MessageValue::Decimal(d ) => Operand::from(d), - MessageValue::Date( d) => Operand::from(d), - MessageValue::Time(t) | - MessageValue::Counter(t) | - MessageValue::Timestamp(t) => Operand::from(t), - MessageValue::Uuid(u) | - MessageValue::Timeuuid( u) => Operand::from( u ), - - MessageValue::List(l) => {Operand::List( l.iter().map( |x| Operand::from(x).to_string()).collect())} - - MessageValue::Rows(r) => { - Operand::Tuple( r.iter().map( |row| row.iter().map( Operand::from ).collect()).map( Operand::Tuple ).collect()) + MessageValue::Inet(i) => Operand::from(i), + MessageValue::Varint(i) => Operand::from(i), + MessageValue::Decimal(d) => Operand::from(d), + MessageValue::Date(d) => Operand::from(d), + MessageValue::Time(t) | MessageValue::Counter(t) | MessageValue::Timestamp(t) => { + Operand::from(t) } + MessageValue::Uuid(u) | MessageValue::Timeuuid(u) => Operand::from(u), - MessageValue::NamedRows(r) => { - Operand::Tuple( r.iter().map( |nr| Operand::Map(nr.iter().map( |(k,v)| (k.clone(), Operand::from(v).to_string())).collect())).collect()) + MessageValue::List(l) => { + Operand::List(l.iter().map(|x| Operand::from(x).to_string()).collect()) } - MessageValue::Set(s) => { - Operand::Set( s.iter().map( |m| Operand::from(m).to_string()).collect()) - } - MessageValue::Map(m) => { - Operand::Map( m.iter().map( |(k,v)| (Operand::from(k).to_string(), Operand::from(v).to_string())).collect() ) - } + MessageValue::Rows(r) => Operand::Tuple( + r.iter() + .map(|row| row.iter().map(Operand::from).collect()) + .map(Operand::Tuple) + .collect(), + ), - MessageValue::FragmentedResponse(t) | - MessageValue::Tuple(t) => { - Operand::Tuple( t.iter().map( Operand::from ).collect()) - } + MessageValue::NamedRows(r) => Operand::Tuple( + r.iter() + .map(|nr| { + Operand::Map( + nr.iter() + .map(|(k, v)| (k.clone(), Operand::from(v).to_string())) + .collect(), + ) + }) + .collect(), + ), - MessageValue::Udt(d) | - MessageValue::Document(d) => { - Operand::Map( d.iter().map( |(k,v)| (k.clone(),Operand::from(v).to_string())).collect()) + MessageValue::Set(s) => { + Operand::Set(s.iter().map(|m| Operand::from(m).to_string()).collect()) } + MessageValue::Map(m) => Operand::Map( + m.iter() + .map(|(k, v)| (Operand::from(k).to_string(), Operand::from(v).to_string())) + .collect(), + ), - MessageValue::None => { - Operand::Null + MessageValue::FragmentedResponse(t) | MessageValue::Tuple(t) => { + Operand::Tuple(t.iter().map(Operand::from).collect()) } + MessageValue::Udt(d) | MessageValue::Document(d) => Operand::Map( + d.iter() + .map(|(k, v)| (k.clone(), Operand::from(v).to_string())) + .collect(), + ), + + MessageValue::None => Operand::Null, } } } impl From<&Operand> for MessageValue { fn from(operand: &Operand) -> Self { - operand.as_cassandra_type().map_or( MessageValue::None, MessageValue::create_element) + operand + .as_cassandra_type() + .map_or(MessageValue::None, MessageValue::create_element) } } @@ -531,10 +541,10 @@ impl From<&MessageValue> for DataTypeName { //DataTypeName::Int IntSize::I64 => DataTypeName::BigInt, IntSize::I32 => DataTypeName::Int, - IntSize::I16 => DataTypeName::SmallInt, + IntSize::I16 => DataTypeName::SmallInt, IntSize::I8 => DataTypeName::TinyInt, } - }, + } MessageValue::Double(_) => DataTypeName::Double, MessageValue::Float(_) => DataTypeName::Float, MessageValue::Boolean(_) => DataTypeName::Boolean, @@ -557,11 +567,10 @@ impl From<&MessageValue> for DataTypeName { MessageValue::Counter(_) => DataTypeName::Counter, MessageValue::Tuple(_) => DataTypeName::Tuple, MessageValue::Udt(_) => DataTypeName::Tuple, - MessageValue::NULL => DataTypeName::Custom( "NULL".to_string()), - MessageValue::None => DataTypeName::Custom( "None".to_string()), + MessageValue::NULL => DataTypeName::Custom("NULL".to_string()), + MessageValue::None => DataTypeName::Custom("None".to_string()), } } - } impl From for MessageValue { diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 26c0ea5e6..bc801aaf3 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -1,4 +1,4 @@ -use crate::frame::{CassandraOperation, CassandraResult, Frame}; +use crate::frame::{CassandraOperation, CassandraResult, Frame, CQL}; use crate::message::{IntSize, Message, MessageValue}; use crate::{ error::ChainResponse, @@ -73,18 +73,22 @@ fn extract_native_port_column(message: &mut Message) -> Vec { if let Some(Frame::Cassandra(cassandra)) = message.frame() { if let CassandraOperation::Query { query, .. } = &cassandra.operation { if let CassandraStatement::Select(select) = &query.statement { - select - .columns - .iter() - .for_each(|select_element| match select_element { - SelectElement::Column(col_name) => { - if col_name.name.eq("native_port") { - result.push(col_name.alias_or_name()); + if let Some(table_name) = CQL::get_table_name(&query.statement) { + if table_name.eq("system.peers_v2") { + select + .columns + .iter() + .for_each(|select_element| match select_element { + SelectElement::Column(col_name) => { + if col_name.name.eq("native_port") { + result.push(col_name.alias_or_name()); + } } - } - SelectElement::Star => result.push("native_port".to_string()), - _ => {} - }); + SelectElement::Star => result.push("native_port".to_string()), + _ => {} + }); + } + } } } } @@ -95,15 +99,21 @@ fn extract_native_port_column(message: &mut Message) -> Vec { /// Only Cassandra queries to the `system.peers` table found via the `is_system_peers` function should be passed to this fn rewrite_port(message: &mut Message, column_names: &[String], new_port: u32) { if let Some(Frame::Cassandra(frame)) = message.frame() { - if let CassandraOperation::Result(CassandraResult::Rows { value, metadata }) = &mut frame.operation + if let CassandraOperation::Result(CassandraResult::Rows { value, metadata }) = + &mut frame.operation { - let port_column_index : Vec= metadata + let port_column_index: Vec = metadata .col_specs - .iter().enumerate() - .filter_map(|(idx, col)| if column_names.contains(&col.name) { - Some(idx) - } else { None } - ).collect(); + .iter() + .enumerate() + .filter_map(|(idx, col)| { + if column_names.contains(&col.name) { + Some(idx) + } else { + None + } + }) + .collect(); if let MessageValue::Rows(rows) = value { for row in rows.iter_mut() { diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 46da3753b..ec8377df2 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,5 +1,5 @@ use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, CQL}; use crate::message::MessageValue; use crate::transforms::protect::key_management::{KeyManager, KeyManagerConfig}; use crate::transforms::{Transform, Transforms, Wrapper}; @@ -7,19 +7,17 @@ use anyhow::anyhow; use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; -use serde::{Deserialize, Serialize}; -use sodiumoxide::crypto::secretbox; -use sodiumoxide::crypto::secretbox::{Key, Nonce}; -use std::collections::HashMap; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use cql3_parser::insert::InsertValues; use cql3_parser::select::SelectElement; +use serde::{Deserialize, Serialize}; +use sodiumoxide::crypto::secretbox; +use sodiumoxide::crypto::secretbox::{Key, Nonce}; use sodiumoxide::hex; -use sqlparser::test_utils::table; +use std::collections::HashMap; use tracing::warn; - mod aws_kms; mod key_management; mod local_kek; @@ -100,10 +98,10 @@ impl From<&Protected> for Operand { Protected::Plaintext(_) => panic!( "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), - Protected::Ciphertext { .. } => { - Operand::Const( format!( "0X{}", - hex::encode(serde_json::to_vec(&p).unwrap()))) - } + Protected::Ciphertext { .. } => Operand::Const(format!( + "0X{}", + hex::encode(serde_json::to_vec(&p).unwrap()) + )), } } } @@ -175,34 +173,48 @@ impl ProtectConfig { } /// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. -async fn encrypt_columns( statement : &mut CassandraStatement, columns : &[String], key_source : &KeyManager, key_id : &str) -> Result { - +async fn encrypt_columns( + statement: &mut CassandraStatement, + columns: &[String], + key_source: &KeyManager, + key_id: &str, +) -> Result { let mut data_changed = false; match statement { CassandraStatement::Insert(insert) => { - let indices :Vec= insert.columns.iter().enumerate() - .filter_map( |(i,col_name)| if columns.contains( col_name ) { Some(i) } else { None }) + let indices: Vec = insert + .columns + .iter() + .enumerate() + .filter_map(|(i, col_name)| { + if columns.contains(col_name) { + Some(i) + } else { + None + } + }) .collect(); match &mut insert.values { InsertValues::Values(value_operands) => { for idx in indices { - let mut protected = Protected::Plaintext(MessageValue::from( &value_operands[idx] )); + let mut protected = + Protected::Plaintext(MessageValue::from(&value_operands[idx])); protected = protected.protect(key_source, key_id).await?; - value_operands[idx] = Operand::from( &protected ); + value_operands[idx] = Operand::from(&protected); data_changed = true - } - }, + } + } InsertValues::Json(_) => { // TODO parse json and encrypt. } } } CassandraStatement::Update(update) => { - for assignment in &mut update.assignments { - if columns.contains( &assignment.name.column ) { - let mut protected = Protected::Plaintext( MessageValue::from(&assignment.value) ); - protected = protected.protect( key_source, key_id ).await?; - assignment.value = Operand::from( &protected ); + for assignment in &mut update.assignments { + if columns.contains(&assignment.name.column) { + let mut protected = Protected::Plaintext(MessageValue::from(&assignment.value)); + protected = protected.protect(key_source, key_id).await?; + assignment.value = Operand::from(&protected); data_changed = true; } } @@ -222,16 +234,21 @@ impl Transform for Protect { let mut data_changed = false; if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) = message.frame() + operation: CassandraOperation::Query { query, .. }, + .. + })) = message.frame() { - if let Some(table_name) = query.get_table_name() { - if let Some((_, tables)) = - self.keyspace_table_columns.get_key_value(table_name) + if let Some(table_name) = CQL::get_table_name(&query.statement) { + if let Some((_, tables)) = self.keyspace_table_columns.get_key_value(table_name) { - if let Some((_, columns)) = tables.get_key_value(table_name ) { - data_changed = encrypt_columns(&mut query.statement, columns, &self.key_source, &self.key_id).await?; + if let Some((_, columns)) = tables.get_key_value(table_name) { + data_changed = encrypt_columns( + &mut query.statement, + columns, + &self.key_source, + &self.key_id, + ) + .await?; } } } @@ -257,23 +274,25 @@ impl Transform for Protect { })) = response.frame() { if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) = request.frame() + operation: CassandraOperation::Query { query, .. }, + .. + })) = request.frame() { - if let Some(table_name) = query.get_table_name() { + if let Some(table_name) = CQL::get_table_name(&query.statement) { if let Some((_keyspace, tables)) = - self.keyspace_table_columns.get_key_value(table_name) + self.keyspace_table_columns.get_key_value(table_name) { if let Some((_table, protect_columns)) = - tables.get_key_value(table_name) + tables.get_key_value(table_name) { if let CassandraStatement::Select(select) = &query.statement { - let positions: Vec = select.columns.iter().enumerate() + let positions: Vec = select + .columns + .iter() + .enumerate() .filter_map(|(i, col)| { if let SelectElement::Column(named) = col { - if protect_columns.contains(&named.name) - { + if protect_columns.contains(&named.name) { Some(i) } else { None @@ -281,7 +300,8 @@ impl Transform for Protect { } else { None } - }).collect(); + }) + .collect(); for row in &mut *rows { for index in &positions { if let Some(v) = row.get_mut(*index) { @@ -316,4 +336,3 @@ impl Transform for Protect { Ok(result) } } - diff --git a/shotover-proxy/src/transforms/query_counter.rs b/shotover-proxy/src/transforms/query_counter.rs index ae6cd2034..4c674dac1 100644 --- a/shotover-proxy/src/transforms/query_counter.rs +++ b/shotover-proxy/src/transforms/query_counter.rs @@ -7,7 +7,6 @@ use async_trait::async_trait; use metrics::{counter, register_counter}; use serde::Deserialize; - #[derive(Debug, Clone)] pub struct QueryCounter { counter_name: String, diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index f2ae283d1..7d9cef967 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,7 +1,7 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame}; -use crate::message::{Message, Messages, QueryType}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame, CQL}; +use crate::message::{Message, MessageValue, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ build_chain_from_config, Transform, Transforms, TransformsConfig, Wrapper, @@ -10,13 +10,11 @@ use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use cassandra_protocol::frame::Version; +use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::common::{Operand, RelationOperator}; +use itertools::Itertools; use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; -use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{Operand, PrimaryKey, RelationElement, RelationOperator, WhereClause}; - -const TRUE: [u8; 1] = [0x1]; -const FALSE: [u8; 1] = [0x0]; #[derive(Deserialize, Debug, Clone)] pub struct RedisConfig { @@ -30,13 +28,6 @@ pub struct TableCacheSchema { range_key: Vec, } - -impl From<&PrimaryKey> for TableCacheSchema { - fn from(value: &PrimaryKey) -> TableCacheSchema { - TableCacheSchema { partition_key: value.partition.clone(), range_key: value.clustering.clone() } - } -} - impl RedisConfig { pub async fn get_transform(&self, topics: &TopicHolder) -> Result { Ok(Transforms::RedisCache(SimpleRedisCache { @@ -60,7 +51,7 @@ impl SimpleRedisCache { async fn get_or_update_from_cache( &mut self, - mut messages_cass_request: Messages + mut messages_cass_request: Messages, ) -> ChainResponse { // This function is a little hard to follow, so here's an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. @@ -81,76 +72,60 @@ impl SimpleRedisCache { // * These are the cassandra responses that we return from the function. let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); - // process only as long as all the cass requests can be answered with from the cache. - for cass_request in &mut messages_cass_request { match cass_request.frame() { Some(Frame::Cassandra(frame)) => { - // get the statements that have table names - if let CassandraOperation::Query { query, params } = &frame.operation { - if let Some(table_name) = query.get_table_name() { - - // process the table if it is listed in the caching schema - if let Some(table_cache_schema) = self + for query in frame.operation.queries()? { + if let Some(table_name) = CQL::get_table_name(query) { + let table_cache_schema = self .caching_schema - .get(table_name) { - match query.statement { - CassandraStatement::Insert(_) | - CassandraStatement::Update(_) | - CassandraStatement::Delete(_) => { - let result = build_redis_ast_from_cql3(&query.statement, table_cache_schema); - if result.is_ok() { - messages_redis_request.push(Message::from_frame(Frame::Redis(result.unwrap()))); - } - }, - _ => {} - } - } + .get(table_name) + .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; + + messages_redis_request.push(Message::from_frame(Frame::Redis( + build_redis_ast_from_cql3(query, table_cache_schema)?, + ))); } } - }, - _ => { - // statements that we can not handle } + message => bail!("cannot fetch {message:?} from cache"), } } - // if we can handle all the query from the cache then do so - if ! &messages_redis_request.is_empty() { - let messages_redis_response = self - .cache_chain - .process_request( - Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), - "clientdetailstodo".to_string(), - ) - .await?; - - // Replace cass_request messages with cassandra responses in place. - // We reuse the vec like this to save allocations. - let mut messages_redis_response_iter = messages_redis_response.into_iter(); - for cass_request in &mut messages_cass_request { - let mut redis_responses = vec![]; - if let Some(Frame::Cassandra(frame)) = cass_request.frame() { - if let Ok(queries) = frame.operation.queries() { - for _query in queries { - redis_responses.push(messages_redis_response_iter.next()); - } - } - } - // TODO: Translate the redis_responses into a cassandra result - *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - operation: CassandraOperation::Result(CassandraResult::Void), - stream_id: cass_request.stream_id().unwrap(), - tracing_id: None, - warnings: vec![], - })); + let messages_redis_response = self + .cache_chain + .process_request( + Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), + "clientdetailstodo".to_string(), + ) + .await?; + + // Replace cass_request messages with cassandra responses in place. + // We reuse the vec like this to save allocations. + let mut messages_redis_response_iter = messages_redis_response.into_iter(); + for cass_request in &mut messages_cass_request { + let mut redis_responses = vec![]; + if let Some(Frame::Cassandra(frame)) = cass_request.frame() { + if let Ok(queries) = frame.operation.queries() { + for _query in queries { + redis_responses.push(messages_redis_response_iter.next()); + } } } + // TODO: Translate the redis_responses into a cassandra result + *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { + version: Version::V4, + operation: CassandraOperation::Result(CassandraResult::Void), + stream_id: cass_request.stream_id().unwrap(), + tracing_id: None, + warnings: vec![], + })); + } Ok(messages_cass_request) } } + fn append_prefix_min(min: &mut Vec) { if min.is_empty() { min.push(b'['); @@ -167,32 +142,19 @@ fn append_prefix_max(max: &mut Vec) { } } -fn build_zrangebylex_min_max_from_cql3( - operator : &RelationOperator, +fn operand_to_bytes(operand: &Operand) -> Vec { + let message_value = MessageValue::from(operand); + let bytes: cassandra_protocol::types::value::Bytes = message_value.into(); + bytes.into_inner() +} + +fn build_zrangebylex_min_max_from_sql( + operator: &RelationOperator, operand: &Operand, min: &mut Vec, max: &mut Vec, ) -> Result<()> { - - let mut bytes = - match operand { - Operand::Const(value) => { - Vec::from( - match value.to_uppercase().as_str() { - "TRUE" => &TRUE, - "FALSE" => &FALSE, - _ => value.as_bytes(), - }) - } - Operand::Map(_) | - Operand::Set(_) | - Operand::List(_) | - Operand::Tuple(_) | - Operand::Column(_) | - Operand::Func(_) => Vec::from(operand.to_string().as_bytes()), - Operand::Null => vec!(), - }; - + let mut bytes = operand_to_bytes(operand); match operator { RelationOperator::LessThan => { let last_byte = bytes.last_mut().unwrap(); @@ -223,212 +185,184 @@ fn build_zrangebylex_min_max_from_cql3( min.extend(bytes.iter()); } // should "IN"" be converted to an "or" "eq" combination - - RelationOperator::NotEqual | - RelationOperator::In | - RelationOperator::Contains | - RelationOperator::ContainsKey | - RelationOperator::IsNot => { + RelationOperator::NotEqual + | RelationOperator::In + | RelationOperator::Contains + | RelationOperator::ContainsKey + | RelationOperator::IsNot => { return Err(anyhow!("Couldn't build query")); } } Ok(()) } -fn build_redis_frames_from_where_clause( where_clause : &[RelationElement], table_cache_schema: &TableCacheSchema) -> Result> { - let mut min: Vec = Vec::new(); - let mut max: Vec = Vec::new(); - let mut had_err = None; - - let where_columns = WhereClause::get_column_relation_element_map( where_clause ); - - // process the partition key - where_columns.iter().filter(|(name,_relation_elements)| { - ! table_cache_schema.partition_key.contains( name ) - }).for_each( |(_name,relation_elements)| { - for relation_element in relation_elements { - for operand in &relation_element.value { - let x = build_zrangebylex_min_max_from_cql3(&relation_element.oper, operand, &mut min, &mut max, ); - if x.is_err() { - had_err = x.err() - } - } - } - }); - - if let Some(e) = had_err { - return Err(e); - } - let min = if min.is_empty() { - Bytes::from_static(b"-") - } else { - Bytes::from(min) - }; - let max = if max.is_empty() { - Bytes::from_static(b"+") - } else { - Bytes::from(max) - }; - - let pk = table_cache_schema - .partition_key - .iter() - .filter_map(|k| { - let x = where_columns.get(k); - x?; - let y = x.unwrap().iter().find(|x| x.oper == RelationOperator::Equal); - y?; - Some(&y.unwrap().value) - }) - .fold(BytesMut::new(), |mut acc, v| { - v.iter().for_each(|operand| acc.extend(operand.to_string().as_bytes())); - acc - }); - Ok(vec![ - RedisFrame::BulkString("ZRANGEBYLEX".into()), - RedisFrame::BulkString(pk.freeze()), - RedisFrame::BulkString(min), - RedisFrame::BulkString(max), - ]) -} - -fn extract_partition_key( partition_key_columns : &[String], value_map : &BTreeMap) -> Result{ - let pk = partition_key_columns - .iter() - .map(|k| - value_map.get(k.as_str()).unwrap() - ) - .fold(BytesMut::new(), |mut acc, v| { - acc.extend(v.to_string().as_bytes()); - acc - }); - Ok(pk) -} - -fn build_redis_ast_from_cql3 ( +fn build_redis_ast_from_cql3( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, -) -> Result -{ - match statement { - CassandraStatement::Select(select) => { - if select.where_clause.is_some() { - Ok(RedisFrame::Array( build_redis_frames_from_where_clause( select.where_clause.as_ref().unwrap(),table_cache_schema)?)) - } else { - Err(anyhow!("Can't build query from statement: {}", statement)) +) -> Result { + match statement { + CassandraStatement::Select(select) => { + if select.filtering || !select.columns.is_empty() || !select.where_clause.is_empty() { + Err(anyhow!("Can't build query from expr: {}", select)) + } else { + let mut min: Vec = Vec::new(); + let mut max: Vec = Vec::new(); + + // extract the partition and range operands + // fail if any are missing + let mut partition_segments: HashMap<&str, &Operand> = HashMap::new(); + let mut range_segments: HashMap<&str, (&RelationOperator, &Operand)> = + HashMap::new(); + + for relation_element in &select.where_clause { + if let Operand::Column(column_name) = &relation_element.obj { + // name has to be in partition or range key. + if table_cache_schema.partition_key.contains(&column_name) { + partition_segments.insert(&column_name, &relation_element.value); + } else if table_cache_schema.range_key.contains(&column_name) { + range_segments.insert( + &column_name, + (&relation_element.oper, &relation_element.value), + ); + } else { + return Err(anyhow!( + "Couldn't build query- column {} is not in the key", + column_name + )); + } + } } - } - CassandraStatement::Insert(insert) => { - // partition key from the value map - // values from the remaining parts of the value map. - let value_map : BTreeMap = insert.get_value_map(); - let pk = extract_partition_key( &table_cache_schema.partition_key, &value_map )?; - let mut redis_frames: Vec = vec![ - RedisFrame::BulkString("ZADD".into()), - RedisFrame::BulkString(pk.freeze()), - ]; - add_values_to_redis_frames(table_cache_schema, value_map, &mut redis_frames)?; - Ok(RedisFrame::Array(redis_frames)) - } - CassandraStatement::Update(update) => { - // only want the partition key built from `equals` statements in the where clause - // and values from the set clause - let where_tree = WhereClause::get_column_relation_element_map(&update.where_clause); - let mut value_map :BTreeMap = where_tree - .iter().filter_map( |(k,v)| { - for relation in v { - if relation.oper == RelationOperator::Equal && relation.value.len() == 1 { - return Some((k.clone(),&relation.value[0])); - } + let mut skipping = false; + for column_name in &table_cache_schema.range_key { + if let Some((operator, operand)) = range_segments.get(column_name.as_str()) { + if skipping { + // we skipped an earlier column so this is an error. + return Err(anyhow!( + "Columns in the middle of the range key were skipped" + )); + } + if let Err(e) = build_zrangebylex_min_max_from_sql( + &operator, &operand, &mut min, &mut max, + ) { + return Err(e); } - None - }).collect(); - let mut has_err = false; - update.assignments.iter().for_each( |assignment| { - if assignment.operator.is_some() { - has_err = true; } else { - value_map.insert( assignment.name.to_string(), &assignment.value ); + // once we skip a range key column we have to skip all the rest so set a flag. + skipping = true; + } + } + let min = if min.is_empty() { + Bytes::from_static(b"-") + } else { + Bytes::from(min) + }; + let max = if max.is_empty() { + Bytes::from_static(b"+") + } else { + Bytes::from(max) + }; + + let mut partition_key = BytesMut::new(); + for column_name in &table_cache_schema.partition_key { + if let Some(operand) = partition_segments.get(column_name.as_str()) { + partition_key.extend(operand_to_bytes(operand).iter()); + } else { + return Err(anyhow!("partition column {} missing", column_name)); } - }); - if has_err { - return Err(anyhow!("Set values include operations")); } - let pk = extract_partition_key( &table_cache_schema.partition_key, &value_map )?; - let mut redis_frames: Vec = vec![ - RedisFrame::BulkString("ZADD".into()), - RedisFrame::BulkString(pk.freeze()), + + let commands_buffer = vec![ + RedisFrame::BulkString("ZRANGEBYLEX".into()), + RedisFrame::BulkString(partition_key.freeze()), + RedisFrame::BulkString(min), + RedisFrame::BulkString(max), ]; - add_values_to_redis_frames(table_cache_schema, value_map, &mut redis_frames)?; - Ok(RedisFrame::Array(redis_frames)) + Ok(RedisFrame::Array(commands_buffer)) } - _ => unreachable!(), } + CassandraStatement::Insert(insert) => { + let query_values = insert.get_value_map(); + add_query_values(table_cache_schema, query_values) + } + CassandraStatement::Update(update) => { + let mut query_values: BTreeMap<&str, &Operand> = BTreeMap::new(); + for assignment_element in &update.assignments { + if assignment_element.operator.is_some() { + return Err(anyhow!("Update has calculations in values")); + } + if assignment_element.name.idx.is_some() { + return Err(anyhow!("Update has indexed columns")); + } + query_values.insert( + assignment_element.name.column.as_str(), + &assignment_element.value, + ); + } + add_query_values(table_cache_schema, query_values) + } + statement => Err(anyhow!("Cant build query from statement: {}", statement)), + } } -fn add_values_to_redis_frames( +fn add_query_values( table_cache_schema: &TableCacheSchema, - query_values: BTreeMap, - redis_frames : &mut Vec -) -> Result<()> { + query_values: BTreeMap<&str, &Operand>, +) -> Result { + let mut partition_key = BytesMut::new(); + for column_name in &table_cache_schema.partition_key { + if let Some(operand) = query_values.get(column_name.as_str()) { + partition_key.extend(operand_to_bytes(operand).iter()); + } else { + return Err(anyhow!("partition column {} missing", column_name)); + } + } + + let mut clustering = BytesMut::new(); + for column_name in &table_cache_schema.range_key { + if let Some(operand) = query_values.get(column_name.as_str()) { + clustering.extend(operand_to_bytes(operand).iter()); + } else { + return Err(anyhow!("range column {} missing", column_name)); + } + } + + let mut commands_buffer: Vec = vec![ + RedisFrame::BulkString("ZADD".into()), + RedisFrame::BulkString(partition_key.freeze()), + ]; - let mut has_err = None; - let mut clustering = table_cache_schema - .range_key + let values = query_values .iter() - .filter_map(|k| { - if let Some(x) = query_values.get(k.as_str()) { - Some(x) + .filter_map(|(column_name, value)| { + if !table_cache_schema + .partition_key + .contains(&column_name.to_string()) + && !table_cache_schema + .range_key + .contains(&column_name.to_string()) + { + Some(value) } else { - has_err = Some(anyhow!( "Clustering column {} missing from statement", k )); None } }) - .fold(BytesMut::new(), |mut acc, operand| { - acc.extend(operand.to_string().as_bytes()); - acc - }); - if let Some(e) = has_err { - return Err(e); - } - if !clustering.is_empty() { - clustering.put_u8(b':'); - } - redis_frames.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); + .collect_vec(); - query_values - .iter() - .filter_map(|(p, v)| { - if table_cache_schema.partition_key.contains(p) || - table_cache_schema.range_key.contains(p) { - None - } else { Some(*v)} - }) - .for_each( |operand| { - clustering.extend(operand.to_string().as_bytes()); - }); - redis_frames.push(RedisFrame::BulkString(clustering.freeze())); + for operand in values { + commands_buffer.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); + let mut value = clustering.clone(); + if !value.is_empty() { + value.put_u8(b':'); + } + value.extend(operand_to_bytes(operand).iter()); + commands_buffer.push(RedisFrame::BulkString(value.freeze())); + } - Ok(()) + Ok(RedisFrame::Array(commands_buffer)) } #[async_trait] impl Transform for SimpleRedisCache { - fn validate(&self) -> Vec { - let mut errors = self - .cache_chain - .validate() - .iter() - .map(|x| format!(" {x}")) - .collect::>(); - - if !errors.is_empty() { - errors.insert(0, format!("{}:", self.get_name())); - } - - errors - } - async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { let mut updates = false; @@ -465,11 +399,26 @@ impl Transform for SimpleRedisCache { upstream } } + + fn validate(&self) -> Vec { + let mut errors = self + .cache_chain + .validate() + .iter() + .map(|x| format!(" {x}")) + .collect::>(); + + if !errors.is_empty() { + errors.insert(0, format!("{}:", self.get_name())); + } + + errors + } } #[cfg(test)] mod test { - use crate::frame::RedisFrame; + use crate::frame::{RedisFrame, CQL}; use crate::transforms::chain::TransformChain; use crate::transforms::debug::printer::DebugPrinter; use crate::transforms::null::Null; @@ -478,14 +427,13 @@ mod test { }; use crate::transforms::{Transform, Transforms}; use bytes::Bytes; - use std::collections::HashMap; - use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; + use std::collections::HashMap; fn build_query(query_string: &str) -> CassandraStatement { - let ast = CassandraAST::new( query_string ); - assert!( !ast.has_error() ); - ast.statements[0].clone() + let cql = CQL::parse_from_string(query_string); + assert!(!cql.has_error); + cql.statement } #[test] @@ -505,6 +453,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"[123:965")), RedisFrame::BulkString(Bytes::from_static(b"]123:965")), ]); + assert_eq!(expected, query); } @@ -543,7 +492,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), + RedisFrame::BulkString(Bytes::from_static(b"yo:123")), ]); assert_eq!(expected, query); @@ -564,7 +513,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), + RedisFrame::BulkString(Bytes::from_static(b"yo:123")), ]); assert_eq!(expected, query); @@ -587,7 +536,7 @@ mod test { // Semantically databases treat the order of AND clauses differently, Cassandra however requires clustering key predicates be in order // So here we will just expect the order is correct in the query. TODO: we may need to revisit this as support for other databases is added - assert_eq!(query_one, query_two); + assert_ne!(query_one, query_two); } #[test] From 3125826962490df87cb973a09154caa669686067 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 07:10:08 +0100 Subject: [PATCH 09/60] cache tests working --- shotover-proxy/src/frame/cassandra.rs | 7 +- shotover-proxy/src/transforms/redis/cache.rs | 208 +++++++++++-------- 2 files changed, 126 insertions(+), 89 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 1b81da680..82f51a8b4 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -742,9 +742,8 @@ impl CQL { } /// Returns true if there are any parameters in the query - pub fn has_params(&self) -> bool { - let mut statement = self.statement.clone(); - match &mut statement { + pub fn has_params(statement: &CassandraStatement) -> bool { + match statement { CassandraStatement::Delete(delete) => { if CQL::has_params_in_relation_elements(&delete.where_clause) { return true; @@ -754,7 +753,7 @@ impl CQL { } } CassandraStatement::Insert(insert) => { - if let InsertValues::Values(operands) = &mut insert.values { + if let InsertValues::Values(operands) = &insert.values { for operand_idx in 0..operands.len() { if let Operand::Param(_) = &operands[operand_idx] { return true; diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 7d9cef967..233497267 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,7 +1,7 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame, CQL}; -use crate::message::{Message, MessageValue, Messages, QueryType}; +use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ build_chain_from_config, Transform, Transforms, TransformsConfig, Wrapper, @@ -11,7 +11,8 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{Operand, RelationOperator}; +use cql3_parser::common::{Operand, RelationElement, RelationOperator}; +use cql3_parser::select::SelectElement; use itertools::Itertools; use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; @@ -142,19 +143,13 @@ fn append_prefix_max(max: &mut Vec) { } } -fn operand_to_bytes(operand: &Operand) -> Vec { - let message_value = MessageValue::from(operand); - let bytes: cassandra_protocol::types::value::Bytes = message_value.into(); - bytes.into_inner() -} - fn build_zrangebylex_min_max_from_sql( operator: &RelationOperator, operand: &Operand, min: &mut Vec, max: &mut Vec, ) -> Result<()> { - let mut bytes = operand_to_bytes(operand); + let mut bytes = BytesMut::from(operand.to_string().as_bytes()); match operator { RelationOperator::LessThan => { let last_byte = bytes.last_mut().unwrap(); @@ -196,89 +191,120 @@ fn build_zrangebylex_min_max_from_sql( Ok(()) } +fn is_cacheable(statement: &CassandraStatement) -> bool { + CQL::has_params(statement) + || match statement { + CassandraStatement::Select(select) => { + if select.filtering || select.where_clause.is_empty() { + return false; + } + if !select.columns.is_empty() { + if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { + // ok + } else { + return false; + } + } + true + } + CassandraStatement::Insert(insert) => !insert.if_not_exists, + CassandraStatement::Update(update) => !update.if_exists, + + _ => false, + } +} + fn build_redis_ast_from_cql3( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, ) -> Result { + if !is_cacheable(statement) { + return Err(anyhow!("{} is not cacheable", statement)); + } match statement { CassandraStatement::Select(select) => { - if select.filtering || !select.columns.is_empty() || !select.where_clause.is_empty() { - Err(anyhow!("Can't build query from expr: {}", select)) - } else { - let mut min: Vec = Vec::new(); - let mut max: Vec = Vec::new(); - - // extract the partition and range operands - // fail if any are missing - let mut partition_segments: HashMap<&str, &Operand> = HashMap::new(); - let mut range_segments: HashMap<&str, (&RelationOperator, &Operand)> = - HashMap::new(); - - for relation_element in &select.where_clause { - if let Operand::Column(column_name) = &relation_element.obj { - // name has to be in partition or range key. - if table_cache_schema.partition_key.contains(&column_name) { - partition_segments.insert(&column_name, &relation_element.value); - } else if table_cache_schema.range_key.contains(&column_name) { - range_segments.insert( - &column_name, - (&relation_element.oper, &relation_element.value), - ); + let mut min: Vec = Vec::new(); + let mut max: Vec = Vec::new(); + + // extract the partition and range operands + // fail if any are missing + let mut partition_segments: HashMap<&str, &Operand> = HashMap::new(); + let mut range_segments: HashMap<&str, Vec<&RelationElement>> = HashMap::new(); + + for relation_element in &select.where_clause { + if let Operand::Column(column_name) = &relation_element.obj { + // name has to be in partition or range key. + if table_cache_schema.partition_key.contains(&column_name) { + partition_segments.insert(&column_name, &relation_element.value); + } else if table_cache_schema.range_key.contains(&column_name) { + let value = range_segments.get_mut(column_name.as_str()); + let vec = if value.is_none() { + range_segments.insert(&column_name, vec![]); + range_segments.get_mut(column_name.as_str()).unwrap() } else { - return Err(anyhow!( - "Couldn't build query- column {} is not in the key", - column_name - )); - } + value.unwrap() + }; + + vec.push(relation_element); + } else { + return Err(anyhow!( + "Couldn't build query- column {} is not in the key", + column_name + )); } } - let mut skipping = false; - for column_name in &table_cache_schema.range_key { - if let Some((operator, operand)) = range_segments.get(column_name.as_str()) { - if skipping { - // we skipped an earlier column so this is an error. - return Err(anyhow!( - "Columns in the middle of the range key were skipped" - )); - } + } + let mut skipping = false; + for column_name in &table_cache_schema.range_key { + if let Some(relation_elements) = range_segments.get(column_name.as_str()) { + if skipping { + // we skipped an earlier column so this is an error. + return Err(anyhow!( + "Columns in the middle of the range key were skipped" + )); + } + for range_element in relation_elements { if let Err(e) = build_zrangebylex_min_max_from_sql( - &operator, &operand, &mut min, &mut max, + &range_element.oper, + &range_element.value, + &mut min, + &mut max, ) { return Err(e); } - } else { - // once we skip a range key column we have to skip all the rest so set a flag. - skipping = true; } - } - let min = if min.is_empty() { - Bytes::from_static(b"-") - } else { - Bytes::from(min) - }; - let max = if max.is_empty() { - Bytes::from_static(b"+") } else { - Bytes::from(max) - }; - - let mut partition_key = BytesMut::new(); - for column_name in &table_cache_schema.partition_key { - if let Some(operand) = partition_segments.get(column_name.as_str()) { - partition_key.extend(operand_to_bytes(operand).iter()); - } else { - return Err(anyhow!("partition column {} missing", column_name)); - } + // once we skip a range key column we have to skip all the rest so set a flag. + skipping = true; } + } + let min = if min.is_empty() { + Bytes::from_static(b"-") + } else { + Bytes::from(min) + }; + let max = if max.is_empty() { + Bytes::from_static(b"+") + } else { + Bytes::from(max) + }; - let commands_buffer = vec![ - RedisFrame::BulkString("ZRANGEBYLEX".into()), - RedisFrame::BulkString(partition_key.freeze()), - RedisFrame::BulkString(min), - RedisFrame::BulkString(max), - ]; - Ok(RedisFrame::Array(commands_buffer)) + let mut partition_key = BytesMut::new(); + for column_name in &table_cache_schema.partition_key { + if let Some(operand) = partition_segments.get(column_name.as_str()) { + partition_key.extend(operand.to_string().as_bytes()); + } else { + return Err(anyhow!("partition column {} missing", column_name)); + } } + + let commands_buffer = vec![ + RedisFrame::BulkString("ZRANGEBYLEX".into()), + RedisFrame::BulkString(partition_key.freeze()), + RedisFrame::BulkString(min), + RedisFrame::BulkString(max), + ]; + Ok(RedisFrame::Array(commands_buffer)) } CassandraStatement::Insert(insert) => { let query_values = insert.get_value_map(); @@ -298,6 +324,17 @@ fn build_redis_ast_from_cql3( &assignment_element.value, ); } + for relation_element in &update.where_clause { + if relation_element.oper == RelationOperator::Equal { + if let Operand::Column(name) = &relation_element.obj { + if table_cache_schema.partition_key.contains(&name) + || table_cache_schema.range_key.contains(&name) + { + query_values.insert(&name, &relation_element.value); + } + } + } + } add_query_values(table_cache_schema, query_values) } statement => Err(anyhow!("Cant build query from statement: {}", statement)), @@ -311,7 +348,7 @@ fn add_query_values( let mut partition_key = BytesMut::new(); for column_name in &table_cache_schema.partition_key { if let Some(operand) = query_values.get(column_name.as_str()) { - partition_key.extend(operand_to_bytes(operand).iter()); + partition_key.extend(operand.to_string().as_bytes()); } else { return Err(anyhow!("partition column {} missing", column_name)); } @@ -320,7 +357,7 @@ fn add_query_values( let mut clustering = BytesMut::new(); for column_name in &table_cache_schema.range_key { if let Some(operand) = query_values.get(column_name.as_str()) { - clustering.extend(operand_to_bytes(operand).iter()); + clustering.extend(operand.to_string().as_bytes()); } else { return Err(anyhow!("range column {} missing", column_name)); } @@ -331,6 +368,7 @@ fn add_query_values( RedisFrame::BulkString(partition_key.freeze()), ]; + // get values not in partition or cluster key let values = query_values .iter() .filter_map(|(column_name, value)| { @@ -354,7 +392,7 @@ fn add_query_values( if !value.is_empty() { value.put_u8(b':'); } - value.extend(operand_to_bytes(operand).iter()); + value.extend(operand.to_string().as_bytes()); commands_buffer.push(RedisFrame::BulkString(value.freeze())); } @@ -440,7 +478,7 @@ mod test { fn equal_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec![], + range_key: vec!["x".to_string(), "y".to_string()], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); @@ -492,7 +530,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"yo:123")), + RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), ]); assert_eq!(expected, query); @@ -513,7 +551,7 @@ mod test { RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"yo:123")), + RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), ]); assert_eq!(expected, query); @@ -523,7 +561,7 @@ mod test { fn check_deterministic_order_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec![], + range_key: vec!["x".to_string(), "y".to_string()], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); @@ -536,14 +574,14 @@ mod test { // Semantically databases treat the order of AND clauses differently, Cassandra however requires clustering key predicates be in order // So here we will just expect the order is correct in the query. TODO: we may need to revisit this as support for other databases is added - assert_ne!(query_one, query_two); + assert_eq!(query_one, query_two); } #[test] fn range_exclusive_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec![], + range_key: vec!["x".to_string()], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x > 123 AND x < 999"); @@ -564,7 +602,7 @@ mod test { fn range_inclusive_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec![], + range_key: vec!["x".to_string()], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123 AND x <= 999"); @@ -627,7 +665,7 @@ mod test { fn open_range_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec![], + range_key: vec!["x".to_string()], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123"); From 522bad32a8ee64bf313be2640cda01bde7165817 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 10:02:38 +0100 Subject: [PATCH 10/60] Fixed clippy issues --- shotover-proxy/src/frame/cassandra.rs | 88 ++++++++++--------- .../src/transforms/cassandra/peers_rewrite.rs | 5 +- shotover-proxy/src/transforms/protect/mod.rs | 18 ++-- shotover-proxy/src/transforms/redis/cache.rs | 23 +++-- 4 files changed, 68 insertions(+), 66 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 82f51a8b4..a60480e1f 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -251,7 +251,7 @@ impl CassandraFrame { PubSubMessage, */ match &self.operation { - CassandraOperation::Query { query: cql, .. } => match cql.statement { + CassandraOperation::Query { query: cql, .. } => match cql.get_statement() { CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, CassandraStatement::AlterRole(_) => QueryType::SchemaChange, @@ -492,11 +492,23 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQL { - pub statement: CassandraStatement, + statement: Box, pub has_error: bool, } impl CQL { + pub fn get_statement_mut(&mut self) -> &mut CassandraStatement { + self.statement.as_mut() + } + + pub fn get_statement(&self) -> &CassandraStatement { + self.statement.as_ref() + } + + pub fn clone_statement(&self) -> CassandraStatement { + self.get_statement().clone() + } + fn from_value_and_col_spec(value: &Value, col_spec: &ColSpec) -> Operand { match value { Value::Some(vec) => { @@ -515,25 +527,23 @@ impl CQL { query_params: &QueryParams, param_types: &[ColSpec], ) -> Operand { - if let Some(values) = &query_params.values { - if let QueryValues::NamedValues(value_map) = values { - if let Some(value) = value_map.get(name) { - if let Some(idx) = value_map - .iter() - .enumerate() - .filter_map( - |(idx, (key, _value))| { - if key.eq(name) { - Some(idx) - } else { - None - } - }, - ) - .next() - { - return CQL::from_value_and_col_spec(value, ¶m_types[idx]); - } + if let Some(QueryValues::NamedValues(value_map)) = &query_params.values { + if let Some(value) = value_map.get(name) { + if let Some(idx) = value_map + .iter() + .enumerate() + .filter_map( + |(idx, (key, _value))| { + if key.eq(name) { + Some(idx) + } else { + None + } + }, + ) + .next() + { + return CQL::from_value_and_col_spec(value, ¶m_types[idx]); } } } @@ -561,7 +571,7 @@ impl CQL { fn set_operand_if_param( operand: &Operand, - mut param_idx: &mut usize, + param_idx: &mut usize, query_params: &QueryParams, param_types: &[ColSpec], ) -> Operand { @@ -571,7 +581,7 @@ impl CQL { vec.iter().for_each(|o| { vec2.push(CQL::set_operand_if_param( o, - &mut param_idx, + param_idx, query_params, param_types, )) @@ -580,7 +590,7 @@ impl CQL { Operand::Tuple(vec2) } Operand::Param(param_name) => { - if param_name.starts_with("?") { + if param_name.starts_with('?') { CQL::set_param_value_by_position(param_idx, query_params, param_types) } else { let name = param_name.split_at(0).1; @@ -592,7 +602,7 @@ impl CQL { vec.iter().for_each(|o| { vec2.push(CQL::set_operand_if_param( o, - &mut param_idx, + param_idx, query_params, param_types, )) @@ -610,9 +620,9 @@ impl CQL { param_types: &[ColSpec], where_clause: &mut [RelationElement], ) { - for relation_idx in 0..where_clause.len() { - where_clause[relation_idx].value = CQL::set_operand_if_param( - &where_clause[relation_idx].value, + for relation_element in where_clause { + relation_element.value = CQL::set_operand_if_param( + &relation_element.value, param_idx, query_params, param_types, @@ -629,7 +639,7 @@ impl CQL { param_types: &[ColSpec], ) -> CassandraStatement { let mut param_idx: usize = 0; - let mut statement = self.statement.clone(); + let mut statement = self.clone_statement(); match &mut statement { CassandraStatement::Delete(delete) => { CQL::set_relation_elements_values( @@ -647,13 +657,9 @@ impl CQL { } CassandraStatement::Insert(insert) => { if let InsertValues::Values(operands) = &mut insert.values { - for operand_idx in 0..operands.len() { - operands[operand_idx] = CQL::set_operand_if_param( - &mut operands[operand_idx], - &mut param_idx, - params, - param_types, - ) + for operand in operands { + *operand = + CQL::set_operand_if_param(operand, &mut param_idx, params, param_types) } } } @@ -679,7 +685,7 @@ impl CQL { AssignmentOperator::Plus(operand) => { assignment_element.operator = Option::from( AssignmentOperator::Plus(CQL::set_operand_if_param( - &operand, + operand, &mut param_idx, params, param_types, @@ -689,7 +695,7 @@ impl CQL { AssignmentOperator::Minus(operand) => { assignment_element.operator = Option::from( AssignmentOperator::Minus(CQL::set_operand_if_param( - &operand, + operand, &mut param_idx, params, param_types, @@ -754,8 +760,8 @@ impl CQL { } CassandraStatement::Insert(insert) => { if let InsertValues::Values(operands) = &insert.values { - for operand_idx in 0..operands.len() { - if let Operand::Param(_) = &operands[operand_idx] { + for operand in operands { + if let Operand::Param(_) = operand { return true; } } @@ -806,7 +812,7 @@ impl CQL { let ast = CassandraAST::new(cql_query_str); CQL { has_error: ast.has_error(), - statement: ast.statements.first().unwrap().clone(), + statement: Box::new(ast.statements.first().unwrap().clone()), } } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index bc801aaf3..00500be5d 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -72,8 +72,9 @@ fn extract_native_port_column(message: &mut Message) -> Vec { let mut result: Vec = vec![]; if let Some(Frame::Cassandra(cassandra)) = message.frame() { if let CassandraOperation::Query { query, .. } = &cassandra.operation { - if let CassandraStatement::Select(select) = &query.statement { - if let Some(table_name) = CQL::get_table_name(&query.statement) { + let statement = query.get_statement(); + if let CassandraStatement::Select(select) = &statement { + if let Some(table_name) = CQL::get_table_name(statement) { if table_name.eq("system.peers_v2") { select .columns diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index ec8377df2..5509cbede 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -238,17 +238,14 @@ impl Transform for Protect { .. })) = message.frame() { - if let Some(table_name) = CQL::get_table_name(&query.statement) { + let statement = query.get_statement_mut(); + if let Some(table_name) = CQL::get_table_name(statement) { if let Some((_, tables)) = self.keyspace_table_columns.get_key_value(table_name) { if let Some((_, columns)) = tables.get_key_value(table_name) { - data_changed = encrypt_columns( - &mut query.statement, - columns, - &self.key_source, - &self.key_id, - ) - .await?; + data_changed = + encrypt_columns(statement, columns, &self.key_source, &self.key_id) + .await?; } } } @@ -278,14 +275,15 @@ impl Transform for Protect { .. })) = request.frame() { - if let Some(table_name) = CQL::get_table_name(&query.statement) { + let statement = query.get_statement(); + if let Some(table_name) = CQL::get_table_name(statement) { if let Some((_keyspace, tables)) = self.keyspace_table_columns.get_key_value(table_name) { if let Some((_table, protect_columns)) = tables.get_key_value(table_name) { - if let CassandraStatement::Select(select) = &query.statement { + if let CassandraStatement::Select(select) = &statement { let positions: Vec = select .columns .iter() diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 233497267..91f91c882 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -234,18 +234,15 @@ fn build_redis_ast_from_cql3( for relation_element in &select.where_clause { if let Operand::Column(column_name) = &relation_element.obj { // name has to be in partition or range key. - if table_cache_schema.partition_key.contains(&column_name) { - partition_segments.insert(&column_name, &relation_element.value); - } else if table_cache_schema.range_key.contains(&column_name) { + if table_cache_schema.partition_key.contains(column_name) { + partition_segments.insert(column_name, &relation_element.value); + } else if table_cache_schema.range_key.contains(column_name) { let value = range_segments.get_mut(column_name.as_str()); - let vec = if value.is_none() { - range_segments.insert(&column_name, vec![]); - range_segments.get_mut(column_name.as_str()).unwrap() + if let Some(vec) = value { + vec.push(relation_element) } else { - value.unwrap() + range_segments.insert(column_name, vec![relation_element]); }; - - vec.push(relation_element); } else { return Err(anyhow!( "Couldn't build query- column {} is not in the key", @@ -327,10 +324,10 @@ fn build_redis_ast_from_cql3( for relation_element in &update.where_clause { if relation_element.oper == RelationOperator::Equal { if let Operand::Column(name) = &relation_element.obj { - if table_cache_schema.partition_key.contains(&name) - || table_cache_schema.range_key.contains(&name) + if table_cache_schema.partition_key.contains(name) + || table_cache_schema.range_key.contains(name) { - query_values.insert(&name, &relation_element.value); + query_values.insert(name, &relation_element.value); } } } @@ -471,7 +468,7 @@ mod test { fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); assert!(!cql.has_error); - cql.statement + cql.get_statement().clone() } #[test] From a835258343c060c1cd6681fbf091d0e57332d478 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 10:53:28 +0100 Subject: [PATCH 11/60] removed invalid network 'name' attribute --- .../test-configs/cassandra-peers-rewrite/docker-compose.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml b/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml index 205dc3486..74cd1ee58 100644 --- a/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml +++ b/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml @@ -1,7 +1,6 @@ version: "3.3" networks: cluster_subnet: - name: cluster_subnet driver: bridge ipam: driver: default From de0c0d7592d491ff2587a4d855f25bdf6b7367b8 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 14:09:30 +0100 Subject: [PATCH 12/60] updated for cache fix --- shotover-proxy/src/frame/cassandra.rs | 8 +- .../src/transforms/query_counter.rs | 11 +- shotover-proxy/src/transforms/redis/cache.rs | 306 +++++++++++------- 3 files changed, 203 insertions(+), 122 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index a60480e1f..9b623c41e 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -358,12 +358,14 @@ impl CassandraOperation { /// An Err is returned if the operation cannot contain queries or the queries failed to parse. /// /// TODO: This will return a custom iterator type when BATCH support is added - pub fn queries(&mut self) -> Result> { + pub fn queries(&mut self) -> Vec<&mut CassandraStatement> { + let mut result = vec!(); match self { - CassandraOperation::Query { query: cql, .. } => Ok(std::iter::once(&mut cql.statement)), + CassandraOperation::Query { query: cql, .. } => result.push( &mut *cql.statement), // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol - _ => Err(anyhow!("This operation cannot contain queries")), + _ => { } } + result } fn to_direction(&self) -> Direction { diff --git a/shotover-proxy/src/transforms/query_counter.rs b/shotover-proxy/src/transforms/query_counter.rs index 4c674dac1..622e3c78b 100644 --- a/shotover-proxy/src/transforms/query_counter.rs +++ b/shotover-proxy/src/transforms/query_counter.rs @@ -30,15 +30,16 @@ impl Transform for QueryCounter { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { for m in &mut message_wrapper.messages { match m.frame() { - Some(Frame::Cassandra(frame)) => match frame.operation.queries() { - Ok(queries) => { + + Some(Frame::Cassandra(frame)) => { + let queries = frame.operation.queries(); + if queries.is_empty() { + counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => "unknown", "type" => "cassandra"); + } else { for statement in queries { counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => statement.short_name(), "type" => "cassandra"); } } - Err(_) => { - counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => "unknown", "type" => "cassandra"); - } }, Some(Frame::Redis(frame)) => { if let Some(query_type) = get_redis_query_type(frame) { diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 91f91c882..b3310ffdf 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -6,7 +6,7 @@ use crate::transforms::chain::TransformChain; use crate::transforms::{ build_chain_from_config, Transform, Transforms, TransformsConfig, Wrapper, }; -use anyhow::{anyhow, bail, Result}; +use anyhow::Result; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use cassandra_protocol::frame::Version; @@ -16,6 +16,17 @@ use cql3_parser::select::SelectElement; use itertools::Itertools; use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; +use tracing_log::log::info; + +enum CacheableState { + Read, + Update, + Delete, + /// string is the reason for the skip + Skip(String), + /// string is the reason for the error + Err(String), +} #[derive(Deserialize, Debug, Clone)] pub struct RedisConfig { @@ -53,7 +64,7 @@ impl SimpleRedisCache { async fn get_or_update_from_cache( &mut self, mut messages_cass_request: Messages, - ) -> ChainResponse { + ) -> Result { // This function is a little hard to follow, so here's an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. // 1. messages_cass_request: @@ -76,54 +87,83 @@ impl SimpleRedisCache { for cass_request in &mut messages_cass_request { match cass_request.frame() { Some(Frame::Cassandra(frame)) => { - for query in frame.operation.queries()? { - if let Some(table_name) = CQL::get_table_name(query) { - let table_cache_schema = self - .caching_schema - .get(table_name) - .ok_or_else(|| anyhow!("{table_name} not a caching table"))?; - - messages_redis_request.push(Message::from_frame(Frame::Redis( - build_redis_ast_from_cql3(query, table_cache_schema)?, - ))); + for statement in frame.operation.queries() { + let mut state = is_cacheable(statement); + + match state { + // TODO implement proper handling of state + // currently if state is not Skip or Error it just processes + CacheableState::Read | + CacheableState::Update | + CacheableState::Delete => { + if let Some(table_name) = CQL::get_table_name(statement) { + if let Some(table_cache_schema) = self + .caching_schema + .get(table_name) { + let redis_state = build_redis_ast_from_cql3(statement, table_cache_schema); + if redis_state.is_ok() { + messages_redis_request.push(Message::from_frame(Frame::Redis(redis_state.ok().unwrap()))); + } else { + state = redis_state.err().unwrap(); + } + } else { + state = CacheableState::Skip("Table not in caching list".into()); + } + } else { + state = CacheableState::Skip("Table name not in query".into()); + } + } + _ => { + // do nothing here but check again again outside of match as state may have changed + } + } + + if let CacheableState::Err(_) = state { + return Err(state); + } + + if let CacheableState::Skip(_) = state { + return Err(state); } } } - message => bail!("cannot fetch {message:?} from cache"), + _ => { return Err( CacheableState::Err( format!("cannot fetch {cass_request:?} from cache")));}, } } - let messages_redis_response = self + match self .cache_chain .process_request( Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), "clientdetailstodo".to_string(), ) - .await?; - - // Replace cass_request messages with cassandra responses in place. - // We reuse the vec like this to save allocations. - let mut messages_redis_response_iter = messages_redis_response.into_iter(); - for cass_request in &mut messages_cass_request { - let mut redis_responses = vec![]; - if let Some(Frame::Cassandra(frame)) = cass_request.frame() { - if let Ok(queries) = frame.operation.queries() { - for _query in queries { - redis_responses.push(messages_redis_response_iter.next()); + .await { + Ok(messages_redis_response) => { + // Replace cass_request messages with cassandra responses in place. + // We reuse the vec like this to save allocations. + let mut messages_redis_response_iter = messages_redis_response.into_iter(); + for cass_request in &mut messages_cass_request { + let mut redis_responses = vec![]; + if let Some(Frame::Cassandra(frame)) = cass_request.frame() { + for _query in frame.operation.queries() { + redis_responses.push(messages_redis_response_iter.next()); + } } + + // TODO: Translate the redis_responses into a cassandra result + *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { + version: Version::V4, + operation: CassandraOperation::Result(CassandraResult::Void), + stream_id: cass_request.stream_id().unwrap(), + tracing_id: None, + warnings: vec![], + })); } - } + Ok(messages_cass_request) - // TODO: Translate the redis_responses into a cassandra result - *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - operation: CassandraOperation::Result(CassandraResult::Void), - stream_id: cass_request.stream_id().unwrap(), - tracing_id: None, - warnings: vec![], - })); + } + Err(e) => Err( CacheableState::Err( format!("Redis error: {}", e ))), } - Ok(messages_cass_request) } } @@ -148,7 +188,7 @@ fn build_zrangebylex_min_max_from_sql( operand: &Operand, min: &mut Vec, max: &mut Vec, -) -> Result<()> { +) -> Result<(),CacheableState> { let mut bytes = BytesMut::from(operand.to_string().as_bytes()); match operator { RelationOperator::LessThan => { @@ -157,10 +197,12 @@ fn build_zrangebylex_min_max_from_sql( append_prefix_max(max); max.extend(bytes.iter()); + Ok(()) } RelationOperator::LessThanOrEqual => { append_prefix_max(max); max.extend(bytes.iter()); + Ok(()) } RelationOperator::Equal => { @@ -168,59 +210,77 @@ fn build_zrangebylex_min_max_from_sql( append_prefix_max(max); min.extend(bytes.iter()); max.extend(bytes.iter()); + Ok(()) } RelationOperator::GreaterThanOrEqual => { append_prefix_min(min); min.extend(bytes.iter()); + Ok(()) } RelationOperator::GreaterThan => { let last_byte = bytes.last_mut().unwrap(); *last_byte += 1; append_prefix_min(min); min.extend(bytes.iter()); + Ok(()) } // should "IN"" be converted to an "or" "eq" combination RelationOperator::NotEqual | RelationOperator::In | RelationOperator::Contains - | RelationOperator::ContainsKey - | RelationOperator::IsNot => { - return Err(anyhow!("Couldn't build query")); - } + | RelationOperator::ContainsKey => Err( CacheableState::Skip( format!( "{} comparisons are not supported", operator ))), + RelationOperator::IsNot => Err( CacheableState::Skip( format!( "IS NOT NULL comparisons are not supported" ))), } - Ok(()) } -fn is_cacheable(statement: &CassandraStatement) -> bool { - CQL::has_params(statement) - || match statement { - CassandraStatement::Select(select) => { - if select.filtering || select.where_clause.is_empty() { - return false; - } - if !select.columns.is_empty() { - if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { - // ok - } else { - return false; - } - } - true - } - CassandraStatement::Insert(insert) => !insert.if_not_exists, - CassandraStatement::Update(update) => !update.if_exists, - - _ => false, - } +fn is_cacheable(statement: &CassandraStatement) -> CacheableState { + let has_params = CQL::has_params(statement); + + match statement { + CassandraStatement::Select(select) => + if has_params { + CacheableState::Delete + } else if select.filtering { + CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) + } else if select.where_clause.is_empty() { + CacheableState::Skip("Can not cache if where clause is empty".into()) + } else if !select.columns.is_empty() { + if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { + CacheableState::Read + } else { + CacheableState::Skip("Can not cache if columns other than '*' are not selected".into()) + } + } else { + CacheableState::Read + }, + CassandraStatement::Insert(insert) => if has_params || insert.if_not_exists { CacheableState::Delete } else { CacheableState::Update }, + CassandraStatement::Update(update) => { + if has_params || update.if_exists { + CacheableState::Delete + } else { + for assignment_element in &update.assignments { + if assignment_element.operator.is_some() { + info!("Clearing {} cache: {} has calculations in values", update.table_name, assignment_element.name); + return CacheableState::Delete; + } + if assignment_element.name.idx.is_some() { + info!("Clearing {} cache: {} is an indexed columns", update.table_name, assignment_element.name); + return CacheableState::Delete; + } + } + CacheableState::Update + } + }, + + _ => CacheableState::Skip("Statement is not a cacheable type".into()), + } } fn build_redis_ast_from_cql3( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, -) -> Result { - if !is_cacheable(statement) { - return Err(anyhow!("{} is not cacheable", statement)); - } +) -> Result { + match statement { CassandraStatement::Select(select) => { let mut min: Vec = Vec::new(); @@ -244,10 +304,10 @@ fn build_redis_ast_from_cql3( range_segments.insert(column_name, vec![relation_element]); }; } else { - return Err(anyhow!( - "Couldn't build query- column {} is not in the key", + return Err( CacheableState::Skip( format!( + "Couldn't build query - column {} is not in the key", column_name - )); + ))); } } } @@ -256,9 +316,7 @@ fn build_redis_ast_from_cql3( if let Some(relation_elements) = range_segments.get(column_name.as_str()) { if skipping { // we skipped an earlier column so this is an error. - return Err(anyhow!( - "Columns in the middle of the range key were skipped" - )); + return Err(CacheableState::Err( "Columns in the middle of the range key were skipped".into() )); } for range_element in relation_elements { if let Err(e) = build_zrangebylex_min_max_from_sql( @@ -291,7 +349,7 @@ fn build_redis_ast_from_cql3( if let Some(operand) = partition_segments.get(column_name.as_str()) { partition_key.extend(operand.to_string().as_bytes()); } else { - return Err(anyhow!("partition column {} missing", column_name)); + return Err(CacheableState::Err(format!("partition column {} missing", column_name))); } } @@ -309,18 +367,6 @@ fn build_redis_ast_from_cql3( } CassandraStatement::Update(update) => { let mut query_values: BTreeMap<&str, &Operand> = BTreeMap::new(); - for assignment_element in &update.assignments { - if assignment_element.operator.is_some() { - return Err(anyhow!("Update has calculations in values")); - } - if assignment_element.name.idx.is_some() { - return Err(anyhow!("Update has indexed columns")); - } - query_values.insert( - assignment_element.name.column.as_str(), - &assignment_element.value, - ); - } for relation_element in &update.where_clause { if relation_element.oper == RelationOperator::Equal { if let Operand::Column(name) = &relation_element.obj { @@ -334,20 +380,20 @@ fn build_redis_ast_from_cql3( } add_query_values(table_cache_schema, query_values) } - statement => Err(anyhow!("Cant build query from statement: {}", statement)), + _ => unreachable!( "{} should not be passed to build_redis_ast_from_cql3", statement ), } } fn add_query_values( table_cache_schema: &TableCacheSchema, query_values: BTreeMap<&str, &Operand>, -) -> Result { +) -> Result { let mut partition_key = BytesMut::new(); for column_name in &table_cache_schema.partition_key { if let Some(operand) = query_values.get(column_name.as_str()) { partition_key.extend(operand.to_string().as_bytes()); } else { - return Err(anyhow!("partition column {} missing", column_name)); + return Err(CacheableState::Err(format!("partition column {} missing", column_name))); } } @@ -356,7 +402,7 @@ fn add_query_values( if let Some(operand) = query_values.get(column_name.as_str()) { clustering.extend(operand.to_string().as_bytes()); } else { - return Err(anyhow!("range column {} missing", column_name)); + return Err(CacheableState::Err(format!("range column {} missing", column_name))); } } @@ -399,40 +445,72 @@ fn add_query_values( #[async_trait] impl Transform for SimpleRedisCache { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { - let mut updates = false; - + let mut read_cache = true; for m in &mut message_wrapper.messages { if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Query { .. }, .. })) = m.frame() { - if m.get_query_type() == QueryType::Write { - updates = true; - break; + /* let statement = query.get_statement(); + match is_cacheable(statement ) { + CacheableState::Read | + CacheableState::Update | + CacheableState::Delete => {} + CacheableState::Skip(reason) => { + tracing::info!( "Cache skipped for {} due to {}", statement, reason ); + use_cache = false; + } + CacheableState::Err(reason) => { + tracing::error!("Cache failed for {} due to {}", statement, reason); + use_cache = false; + } + } + + */ + match m.get_query_type() { + QueryType::Read => {} + QueryType::Write => { read_cache =false} + QueryType::ReadWrite => { read_cache =false} + QueryType::SchemaChange => { read_cache =false} + QueryType::PubSubMessage => {} } } } // If there are no write queries (all queries are reads) we can use the cache - if !updates { + if read_cache { match self .get_or_update_from_cache(message_wrapper.messages.clone()) .await { - Ok(cr) => Ok(cr), - Err(e) => { - tracing::error!("failed to fetch from cache: {:?}", e); - message_wrapper.call_next_transform().await + Ok(cr) => return Ok(cr), + Err(inner_state) => { + match &inner_state { + CacheableState::Read | + CacheableState::Update | + CacheableState::Delete => { + unreachable!("should not find read, update or delete as an error"); + } + CacheableState::Skip(reason) => { + tracing::info!("Cache skipped: {} ", reason); + message_wrapper.call_next_transform().await + } + CacheableState::Err(reason) => { + tracing::error!("Cache failed: {} ", reason); + message_wrapper.call_next_transform().await + } + } } } } else { let (_cache_res, upstream) = tokio::join!( - self.get_or_update_from_cache(message_wrapper.messages.clone()), - message_wrapper.call_next_transform() - ); + self.get_or_update_from_cache(message_wrapper.messages.clone()), + message_wrapper.call_next_transform() + ); upstream } + } fn validate(&self) -> Vec { @@ -480,7 +558,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -501,7 +579,7 @@ mod test { let ast = build_query("INSERT INTO foo (z, v) VALUES (1, 123)"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -521,7 +599,7 @@ mod test { }; let ast = build_query("INSERT INTO foo (z, c, v) VALUES (1, 'yo' , 123)"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -542,7 +620,7 @@ mod test { let ast = build_query("UPDATE foo SET c = 'yo', v = 123 WHERE z = 1"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -563,11 +641,11 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query_one = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query_one = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let ast = build_query("SELECT * FROM foo WHERE y = 965 AND z = 1 AND x = 123"); - let query_two = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query_two = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); // Semantically databases treat the order of AND clauses differently, Cassandra however requires clustering key predicates be in order // So here we will just expect the order is correct in the query. TODO: we may need to revisit this as support for other databases is added @@ -583,7 +661,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x > 123 AND x < 999"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -604,7 +682,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123 AND x <= 999"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -625,7 +703,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -646,7 +724,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND y = 2"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -667,7 +745,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -680,7 +758,7 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x <= 123"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), From 9ab220aa947efe82b1f8290103f062555641b65c Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 14:18:55 +0100 Subject: [PATCH 13/60] fixed clippy issues --- shotover-proxy/src/frame/cassandra.rs | 7 +- .../src/transforms/query_counter.rs | 3 +- shotover-proxy/src/transforms/redis/cache.rs | 263 +++++++++++------- 3 files changed, 170 insertions(+), 103 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 9b623c41e..f11e2d271 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -359,12 +359,17 @@ impl CassandraOperation { /// /// TODO: This will return a custom iterator type when BATCH support is added pub fn queries(&mut self) -> Vec<&mut CassandraStatement> { - let mut result = vec!(); + let mut result = vec![]; + /* match self { CassandraOperation::Query { query: cql, .. } => result.push( &mut *cql.statement), // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol _ => { } } + */ + if let CassandraOperation::Query { query: cql, .. } = self { + result.push(&mut *cql.statement) + }; result } diff --git a/shotover-proxy/src/transforms/query_counter.rs b/shotover-proxy/src/transforms/query_counter.rs index 622e3c78b..ce10f2e71 100644 --- a/shotover-proxy/src/transforms/query_counter.rs +++ b/shotover-proxy/src/transforms/query_counter.rs @@ -30,7 +30,6 @@ impl Transform for QueryCounter { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { for m in &mut message_wrapper.messages { match m.frame() { - Some(Frame::Cassandra(frame)) => { let queries = frame.operation.queries(); if queries.is_empty() { @@ -40,7 +39,7 @@ impl Transform for QueryCounter { counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => statement.short_name(), "type" => "cassandra"); } } - }, + } Some(Frame::Redis(frame)) => { if let Some(query_type) = get_redis_query_type(frame) { counter!("query_count", 1, "name" => self.counter_name.clone(), "query" => query_type, "type" => "redis"); diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index b3310ffdf..168e91ab3 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -64,7 +64,7 @@ impl SimpleRedisCache { async fn get_or_update_from_cache( &mut self, mut messages_cass_request: Messages, - ) -> Result { + ) -> Result { // This function is a little hard to follow, so here's an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. // 1. messages_cass_request: @@ -93,21 +93,28 @@ impl SimpleRedisCache { match state { // TODO implement proper handling of state // currently if state is not Skip or Error it just processes - CacheableState::Read | - CacheableState::Update | - CacheableState::Delete => { + CacheableState::Read + | CacheableState::Update + | CacheableState::Delete => { if let Some(table_name) = CQL::get_table_name(statement) { - if let Some(table_cache_schema) = self - .caching_schema - .get(table_name) { - let redis_state = build_redis_ast_from_cql3(statement, table_cache_schema); + if let Some(table_cache_schema) = + self.caching_schema.get(table_name) + { + let redis_state = build_redis_ast_from_cql3( + statement, + table_cache_schema, + ); if redis_state.is_ok() { - messages_redis_request.push(Message::from_frame(Frame::Redis(redis_state.ok().unwrap()))); + messages_redis_request.push(Message::from_frame( + Frame::Redis(redis_state.ok().unwrap()), + )); } else { state = redis_state.err().unwrap(); } } else { - state = CacheableState::Skip("Table not in caching list".into()); + state = CacheableState::Skip( + "Table not in caching list".into(), + ); } } else { state = CacheableState::Skip("Table name not in query".into()); @@ -127,7 +134,11 @@ impl SimpleRedisCache { } } } - _ => { return Err( CacheableState::Err( format!("cannot fetch {cass_request:?} from cache")));}, + _ => { + return Err(CacheableState::Err(format!( + "cannot fetch {cass_request:?} from cache" + ))); + } } } @@ -137,7 +148,8 @@ impl SimpleRedisCache { Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), "clientdetailstodo".to_string(), ) - .await { + .await + { Ok(messages_redis_response) => { // Replace cass_request messages with cassandra responses in place. // We reuse the vec like this to save allocations. @@ -145,7 +157,7 @@ impl SimpleRedisCache { for cass_request in &mut messages_cass_request { let mut redis_responses = vec![]; if let Some(Frame::Cassandra(frame)) = cass_request.frame() { - for _query in frame.operation.queries() { + for _query in frame.operation.queries() { redis_responses.push(messages_redis_response_iter.next()); } } @@ -160,9 +172,8 @@ impl SimpleRedisCache { })); } Ok(messages_cass_request) - } - Err(e) => Err( CacheableState::Err( format!("Redis error: {}", e ))), + Err(e) => Err(CacheableState::Err(format!("Redis error: {}", e))), } } } @@ -188,7 +199,7 @@ fn build_zrangebylex_min_max_from_sql( operand: &Operand, min: &mut Vec, max: &mut Vec, -) -> Result<(),CacheableState> { +) -> Result<(), CacheableState> { let mut bytes = BytesMut::from(operand.to_string().as_bytes()); match operator { RelationOperator::LessThan => { @@ -228,59 +239,78 @@ fn build_zrangebylex_min_max_from_sql( RelationOperator::NotEqual | RelationOperator::In | RelationOperator::Contains - | RelationOperator::ContainsKey => Err( CacheableState::Skip( format!( "{} comparisons are not supported", operator ))), - RelationOperator::IsNot => Err( CacheableState::Skip( format!( "IS NOT NULL comparisons are not supported" ))), + | RelationOperator::ContainsKey => Err(CacheableState::Skip(format!( + "{} comparisons are not supported", + operator + ))), + RelationOperator::IsNot => Err(CacheableState::Skip( + "IS NOT NULL comparisons are not supported".into(), + )), } } fn is_cacheable(statement: &CassandraStatement) -> CacheableState { let has_params = CQL::has_params(statement); - match statement { - CassandraStatement::Select(select) => - if has_params { - CacheableState::Delete - } else if select.filtering { - CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) - } else if select.where_clause.is_empty() { - CacheableState::Skip("Can not cache if where clause is empty".into()) - } else if !select.columns.is_empty() { - if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { - CacheableState::Read - } else { - CacheableState::Skip("Can not cache if columns other than '*' are not selected".into()) - } - } else { - CacheableState::Read - }, - CassandraStatement::Insert(insert) => if has_params || insert.if_not_exists { CacheableState::Delete } else { CacheableState::Update }, - CassandraStatement::Update(update) => { - if has_params || update.if_exists { - CacheableState::Delete - } else { - for assignment_element in &update.assignments { - if assignment_element.operator.is_some() { - info!("Clearing {} cache: {} has calculations in values", update.table_name, assignment_element.name); - return CacheableState::Delete; - } - if assignment_element.name.idx.is_some() { - info!("Clearing {} cache: {} is an indexed columns", update.table_name, assignment_element.name); - return CacheableState::Delete; - } - } - CacheableState::Update - } - }, - - _ => CacheableState::Skip("Statement is not a cacheable type".into()), - } + match statement { + CassandraStatement::Select(select) => { + if has_params { + CacheableState::Delete + } else if select.filtering { + CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) + } else if select.where_clause.is_empty() { + CacheableState::Skip("Can not cache if where clause is empty".into()) + } else if !select.columns.is_empty() { + if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { + CacheableState::Read + } else { + CacheableState::Skip( + "Can not cache if columns other than '*' are not selected".into(), + ) + } + } else { + CacheableState::Read + } + } + CassandraStatement::Insert(insert) => { + if has_params || insert.if_not_exists { + CacheableState::Delete + } else { + CacheableState::Update + } + } + CassandraStatement::Update(update) => { + if has_params || update.if_exists { + CacheableState::Delete + } else { + for assignment_element in &update.assignments { + if assignment_element.operator.is_some() { + info!( + "Clearing {} cache: {} has calculations in values", + update.table_name, assignment_element.name + ); + return CacheableState::Delete; + } + if assignment_element.name.idx.is_some() { + info!( + "Clearing {} cache: {} is an indexed columns", + update.table_name, assignment_element.name + ); + return CacheableState::Delete; + } + } + CacheableState::Update + } + } + + _ => CacheableState::Skip("Statement is not a cacheable type".into()), + } } fn build_redis_ast_from_cql3( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, ) -> Result { - match statement { CassandraStatement::Select(select) => { let mut min: Vec = Vec::new(); @@ -304,7 +334,7 @@ fn build_redis_ast_from_cql3( range_segments.insert(column_name, vec![relation_element]); }; } else { - return Err( CacheableState::Skip( format!( + return Err(CacheableState::Skip(format!( "Couldn't build query - column {} is not in the key", column_name ))); @@ -316,7 +346,9 @@ fn build_redis_ast_from_cql3( if let Some(relation_elements) = range_segments.get(column_name.as_str()) { if skipping { // we skipped an earlier column so this is an error. - return Err(CacheableState::Err( "Columns in the middle of the range key were skipped".into() )); + return Err(CacheableState::Err( + "Columns in the middle of the range key were skipped".into(), + )); } for range_element in relation_elements { if let Err(e) = build_zrangebylex_min_max_from_sql( @@ -349,7 +381,10 @@ fn build_redis_ast_from_cql3( if let Some(operand) = partition_segments.get(column_name.as_str()) { partition_key.extend(operand.to_string().as_bytes()); } else { - return Err(CacheableState::Err(format!("partition column {} missing", column_name))); + return Err(CacheableState::Err(format!( + "partition column {} missing", + column_name + ))); } } @@ -380,20 +415,26 @@ fn build_redis_ast_from_cql3( } add_query_values(table_cache_schema, query_values) } - _ => unreachable!( "{} should not be passed to build_redis_ast_from_cql3", statement ), + _ => unreachable!( + "{} should not be passed to build_redis_ast_from_cql3", + statement + ), } } fn add_query_values( table_cache_schema: &TableCacheSchema, query_values: BTreeMap<&str, &Operand>, -) -> Result { +) -> Result { let mut partition_key = BytesMut::new(); for column_name in &table_cache_schema.partition_key { if let Some(operand) = query_values.get(column_name.as_str()) { partition_key.extend(operand.to_string().as_bytes()); } else { - return Err(CacheableState::Err(format!("partition column {} missing", column_name))); + return Err(CacheableState::Err(format!( + "partition column {} missing", + column_name + ))); } } @@ -402,7 +443,10 @@ fn add_query_values( if let Some(operand) = query_values.get(column_name.as_str()) { clustering.extend(operand.to_string().as_bytes()); } else { - return Err(CacheableState::Err(format!("range column {} missing", column_name))); + return Err(CacheableState::Err(format!( + "range column {} missing", + column_name + ))); } } @@ -452,7 +496,7 @@ impl Transform for SimpleRedisCache { .. })) = m.frame() { - /* let statement = query.get_statement(); + /* let statement = query.get_statement(); match is_cacheable(statement ) { CacheableState::Read | CacheableState::Update | @@ -470,9 +514,9 @@ impl Transform for SimpleRedisCache { */ match m.get_query_type() { QueryType::Read => {} - QueryType::Write => { read_cache =false} - QueryType::ReadWrite => { read_cache =false} - QueryType::SchemaChange => { read_cache =false} + QueryType::Write => read_cache = false, + QueryType::ReadWrite => read_cache = false, + QueryType::SchemaChange => read_cache = false, QueryType::PubSubMessage => {} } } @@ -485,32 +529,27 @@ impl Transform for SimpleRedisCache { .await { Ok(cr) => return Ok(cr), - Err(inner_state) => { - match &inner_state { - CacheableState::Read | - CacheableState::Update | - CacheableState::Delete => { - unreachable!("should not find read, update or delete as an error"); - } - CacheableState::Skip(reason) => { - tracing::info!("Cache skipped: {} ", reason); - message_wrapper.call_next_transform().await - } - CacheableState::Err(reason) => { - tracing::error!("Cache failed: {} ", reason); - message_wrapper.call_next_transform().await - } + Err(inner_state) => match &inner_state { + CacheableState::Read | CacheableState::Update | CacheableState::Delete => { + unreachable!("should not find read, update or delete as an error"); } - } + CacheableState::Skip(reason) => { + tracing::info!("Cache skipped: {} ", reason); + message_wrapper.call_next_transform().await + } + CacheableState::Err(reason) => { + tracing::error!("Cache failed: {} ", reason); + message_wrapper.call_next_transform().await + } + }, } } else { let (_cache_res, upstream) = tokio::join!( - self.get_or_update_from_cache(message_wrapper.messages.clone()), - message_wrapper.call_next_transform() - ); + self.get_or_update_from_cache(message_wrapper.messages.clone()), + message_wrapper.call_next_transform() + ); upstream } - } fn validate(&self) -> Vec { @@ -558,7 +597,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -579,7 +620,9 @@ mod test { let ast = build_query("INSERT INTO foo (z, v) VALUES (1, 123)"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -599,7 +642,9 @@ mod test { }; let ast = build_query("INSERT INTO foo (z, c, v) VALUES (1, 'yo' , 123)"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -620,7 +665,9 @@ mod test { let ast = build_query("UPDATE foo SET c = 'yo', v = 123 WHERE z = 1"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), @@ -641,11 +688,15 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query_one = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query_one = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let ast = build_query("SELECT * FROM foo WHERE y = 965 AND z = 1 AND x = 123"); - let query_two = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query_two = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); // Semantically databases treat the order of AND clauses differently, Cassandra however requires clustering key predicates be in order // So here we will just expect the order is correct in the query. TODO: we may need to revisit this as support for other databases is added @@ -661,7 +712,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x > 123 AND x < 999"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -682,7 +735,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123 AND x <= 999"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -703,7 +758,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -724,7 +781,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND y = 2"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -745,7 +804,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), @@ -758,7 +819,9 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x <= 123"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema).ok().unwrap(); + let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + .ok() + .unwrap(); let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), From fe8ac42d5864aa7ee76abbfac41850a3f988cf02 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 14:57:20 +0100 Subject: [PATCH 14/60] Fixed cache bug --- Cargo.lock | 24 ++++++++++---------- shotover-proxy/src/transforms/redis/cache.rs | 16 +++++++------ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7d2c9885f..55da1e6ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -291,7 +291,7 @@ checksum = "8234d29d30873ab5a41e3557b8515d3ecbaefb1ea5be579425b3b0074b6d0e40" [[package]] name = "cassandra-protocol" version = "1.1.0" -source = "git+https://github.com/krojew/cdrs-tokio#1946d5f7025f168aaea7b9ab27eda4708a0d8ba6" +source = "git+https://github.com/krojew/cdrs-tokio#61afd7f9fa897635a78a15a4e3e89864d16101f6" dependencies = [ "arrayref", "bitflags", @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#4b04fdc52d1aaea01dfc366e6d4382ad93c65f8e" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#b2820a666e9e163e511a803fcf884ed2e4e60625" dependencies = [ "bigdecimal", "bytes", @@ -581,9 +581,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.13.2" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e92cb285610dd935f60ee8b4d62dd1988bd12b7ea50579bd6a138201525318e" +checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" dependencies = [ "darling_core", "darling_macro", @@ -591,9 +591,9 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.13.2" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c29e95ab498b18131ea460b2c0baa18cbf041231d122b0b7bfebef8c8e88989" +checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" dependencies = [ "fnv", "ident_case", @@ -605,9 +605,9 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.13.2" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b21dd6b221dd547528bd6fb15f1a3b7ab03b9a06f76bff288a8c629bcfbe7f0e" +checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" dependencies = [ "darling_core", "quote", @@ -1203,9 +1203,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efaa7b300f3b5fe8eb6bf21ce3895e1751d9665086af2d64b42f19701015ff4f" +checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259" [[package]] name = "libloading" @@ -2701,9 +2701,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.90" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704df27628939572cd88d33f171cd6f896f4eaca85252c6e0a72d8d8287ee86f" +checksum = "b683b2b825c8eef438b77c36a06dc262294da3d5a5813fac20da149241dcd44d" dependencies = [ "proc-macro2", "quote", diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 168e91ab3..827ee6dd9 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -401,14 +401,16 @@ fn build_redis_ast_from_cql3( add_query_values(table_cache_schema, query_values) } CassandraStatement::Update(update) => { - let mut query_values: BTreeMap<&str, &Operand> = BTreeMap::new(); + let mut query_values: BTreeMap = BTreeMap::new(); + + update.assignments.iter().for_each( |assignment| {query_values.insert( assignment.name.to_string(), &assignment.value);} ); for relation_element in &update.where_clause { if relation_element.oper == RelationOperator::Equal { if let Operand::Column(name) = &relation_element.obj { if table_cache_schema.partition_key.contains(name) || table_cache_schema.range_key.contains(name) { - query_values.insert(name, &relation_element.value); + query_values.insert(name.clone(), &relation_element.value); } } } @@ -424,11 +426,11 @@ fn build_redis_ast_from_cql3( fn add_query_values( table_cache_schema: &TableCacheSchema, - query_values: BTreeMap<&str, &Operand>, + query_values: BTreeMap, ) -> Result { let mut partition_key = BytesMut::new(); for column_name in &table_cache_schema.partition_key { - if let Some(operand) = query_values.get(column_name.as_str()) { + if let Some(operand) = query_values.get(column_name) { partition_key.extend(operand.to_string().as_bytes()); } else { return Err(CacheableState::Err(format!( @@ -665,9 +667,9 @@ mod test { let ast = build_query("UPDATE foo SET c = 'yo', v = 123 WHERE z = 1"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) - .ok() - .unwrap(); + let result = build_redis_ast_from_cql3(&ast, &table_cache_schema); + let query = result.ok().unwrap(); + let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), From 4328a9aeeaaf386b473c87d2ca0b170d5a167a79 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 7 Apr 2022 15:16:35 +0100 Subject: [PATCH 15/60] fixed formatting issue --- shotover-proxy/src/transforms/redis/cache.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 827ee6dd9..bcf54e4cb 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -403,7 +403,9 @@ fn build_redis_ast_from_cql3( CassandraStatement::Update(update) => { let mut query_values: BTreeMap = BTreeMap::new(); - update.assignments.iter().for_each( |assignment| {query_values.insert( assignment.name.to_string(), &assignment.value);} ); + update.assignments.iter().for_each(|assignment| { + query_values.insert(assignment.name.to_string(), &assignment.value); + }); for relation_element in &update.where_clause { if relation_element.oper == RelationOperator::Equal { if let Operand::Column(name) = &relation_element.obj { @@ -670,7 +672,6 @@ mod test { let result = build_redis_ast_from_cql3(&ast, &table_cache_schema); let query = result.ok().unwrap(); - let expected = RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"ZADD")), RedisFrame::BulkString(Bytes::from_static(b"1")), From a6ce802c9374b53dfa4f16dc6c6ddec4a84f7a30 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 8 Apr 2022 12:58:18 +0100 Subject: [PATCH 16/60] changes for debugging on linux --- .cargo/config.toml | 2 +- shotover-proxy/tests/helpers/cassandra.rs | 8 +++++++- .../cassandra-peers-rewrite/docker-compose.yml | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 58103f6e0..15d5d329f 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -2,4 +2,4 @@ rustflags = [ "-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup", -] \ No newline at end of file +] diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 03c0d32ca..0ac362a92 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,5 +1,6 @@ use cassandra_cpp::{stmt, Cluster, Error, Session, Value, ValueType}; use ordered_float::OrderedFloat; +use test_helpers::try_wait_for_socket_to_open; pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { for contact_point in contact_points.split(',') { @@ -10,7 +11,12 @@ pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { cluster.set_credentials("cassandra", "cassandra").unwrap(); cluster.set_port(port).ok(); cluster.set_load_balance_round_robin(); - cluster.connect().unwrap() + let result = cluster.connect(); + if let Some(err) = &result.as_ref().err() { + assert!( false, "{}",err ); + } + result.unwrap() + //cluster.connect().unwrap() } #[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Ord)] diff --git a/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml b/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml index 74cd1ee58..205dc3486 100644 --- a/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml +++ b/shotover-proxy/tests/test-configs/cassandra-peers-rewrite/docker-compose.yml @@ -1,6 +1,7 @@ version: "3.3" networks: cluster_subnet: + name: cluster_subnet driver: bridge ipam: driver: default From bf224f3e01d573dd86c2f55c0e05b86af9d0438e Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 8 Apr 2022 13:02:13 +0100 Subject: [PATCH 17/60] fixed formatting --- shotover-proxy/tests/helpers/cassandra.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 0ac362a92..700656a89 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -13,10 +13,9 @@ pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { cluster.set_load_balance_round_robin(); let result = cluster.connect(); if let Some(err) = &result.as_ref().err() { - assert!( false, "{}",err ); + assert!(false, "{}", err); } result.unwrap() - //cluster.connect().unwrap() } #[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Ord)] From 7362e2005ba71f933de759ff17bc01971a908dd3 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 12 Apr 2022 07:36:25 +0100 Subject: [PATCH 18/60] added handling of multiple statements in CQL --- Cargo.lock | 46 +- shotover-proxy/src/frame/cassandra.rs | 485 +++++++++++------- .../src/transforms/cassandra/peers_rewrite.rs | 35 +- shotover-proxy/src/transforms/mod.rs | 2 + shotover-proxy/src/transforms/protect/mod.rs | 90 ++-- shotover-proxy/src/transforms/redis/cache.rs | 85 +-- .../cassandra_int_tests/basic_driver_tests.rs | 4 +- shotover-proxy/tests/helpers/cassandra.rs | 2 + 8 files changed, 436 insertions(+), 313 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 55da1e6ea..a5d8c4eb8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#b2820a666e9e163e511a803fcf884ed2e4e60625" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#c2d566bce5f070b93cbbf79a3d6eaa6e58ad15af" dependencies = [ "bigdecimal", "bytes", @@ -1175,9 +1175,9 @@ checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" [[package]] name = "js-sys" -version = "0.3.56" +version = "0.3.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04" +checksum = "671a26f820db17c2a2750743f1dd03bafd15b98c9f30c7c2628c024c05d73397" dependencies = [ "wasm-bindgen", ] @@ -1954,9 +1954,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632d02bff7f874a36f33ea8bb416cd484b90cc66c1194b1a1110d067a7013f58" +checksum = "a1feb54ed693b93a84e14094943b84b7c4eae204c512b7ccb95ab0c66d278ad1" dependencies = [ "proc-macro2", ] @@ -2975,9 +2975,9 @@ checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6" [[package]] name = "tracing" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1bdf54a7c28a2bbf701e1d2233f6c77f473486b94bee4f9678da5a148dca7f" +checksum = "80b9fa4360528139bc96100c160b7ae879f5567f49f1782b0b02035b0358ebf3" dependencies = [ "cfg-if", "pin-project-lite", @@ -3031,9 +3031,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9df98b037d039d03400d9dd06b0f8ce05486b5f25e9a2d7d36196e142ebbc52" +checksum = "4bc28f93baff38037f64e6f43d34cfa1605f27a49c34e8a04c5e78b0babf2596" dependencies = [ "ansi_term", "lazy_static", @@ -3196,9 +3196,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06" +checksum = "27370197c907c55e3f1a9fbe26f44e937fe6451368324e009cba39e139dc08ad" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3206,9 +3206,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca" +checksum = "53e04185bfa3a779273da532f5025e33398409573f348985af9a1cbf3774d3f4" dependencies = [ "bumpalo", "lazy_static", @@ -3221,9 +3221,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eb6ec270a31b1d3c7e266b999739109abce8b6c87e4b31fcfcd788b65267395" +checksum = "6f741de44b75e14c35df886aff5f1eb73aa114fa5d4d00dcd37b5e01259bf3b2" dependencies = [ "cfg-if", "js-sys", @@ -3233,9 +3233,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01" +checksum = "17cae7ff784d7e83a2fe7611cfe766ecf034111b49deb850a3dc7699c08251f5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3243,9 +3243,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc" +checksum = "99ec0dc7a4756fffc231aab1b9f2f578d23cd391390ab27f952ae0c9b3ece20b" dependencies = [ "proc-macro2", "quote", @@ -3256,15 +3256,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.79" +version = "0.2.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2" +checksum = "d554b7f530dee5964d9a9468d95c1f8b8acae4f282807e7d27d4b03099a46744" [[package]] name = "web-sys" -version = "0.3.56" +version = "0.3.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb" +checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index f11e2d271..b18152696 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -25,15 +25,19 @@ use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, RelationElement}; use cql3_parser::insert::InsertValues; use cql3_parser::update::AssignmentOperator; +use itertools::Itertools; use nonzero_ext::nonzero; use sodiumoxide::hex; use std::convert::TryInto; +use std::fmt::{Display, Formatter}; use std::net::IpAddr; use std::num::NonZeroU32; use std::str::FromStr; +use tracing::info; use uuid::Uuid; use crate::message::{MessageValue, QueryType}; +use crate::message::QueryType::PubSubMessage; /// Extract the length of a BATCH statement (count of requests) from the body bytes fn get_batch_len(bytes: &[u8]) -> Result { @@ -242,6 +246,7 @@ impl CassandraFrame { }) } + /// returns the query type for the current statement. pub fn get_query_type(&self) -> QueryType { /* Read, @@ -251,48 +256,29 @@ impl CassandraFrame { PubSubMessage, */ match &self.operation { - CassandraOperation::Query { query: cql, .. } => match cql.get_statement() { - CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, - CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, - CassandraStatement::AlterRole(_) => QueryType::SchemaChange, - CassandraStatement::AlterTable(_) => QueryType::SchemaChange, - CassandraStatement::AlterType(_) => QueryType::SchemaChange, - CassandraStatement::AlterUser(_) => QueryType::SchemaChange, - CassandraStatement::ApplyBatch => QueryType::ReadWrite, - CassandraStatement::CreateAggregate(_) => QueryType::SchemaChange, - CassandraStatement::CreateFunction(_) => QueryType::SchemaChange, - CassandraStatement::CreateIndex(_) => QueryType::SchemaChange, - CassandraStatement::CreateKeyspace(_) => QueryType::SchemaChange, - CassandraStatement::CreateMaterializedView(_) => QueryType::SchemaChange, - CassandraStatement::CreateRole(_) => QueryType::SchemaChange, - CassandraStatement::CreateTable(_) => QueryType::SchemaChange, - CassandraStatement::CreateTrigger(_) => QueryType::SchemaChange, - CassandraStatement::CreateType(_) => QueryType::SchemaChange, - CassandraStatement::CreateUser(_) => QueryType::SchemaChange, - CassandraStatement::Delete(_) => QueryType::Write, - CassandraStatement::DropAggregate(_) => QueryType::SchemaChange, - CassandraStatement::DropFunction(_) => QueryType::SchemaChange, - CassandraStatement::DropIndex(_) => QueryType::SchemaChange, - CassandraStatement::DropKeyspace(_) => QueryType::SchemaChange, - CassandraStatement::DropMaterializedView(_) => QueryType::SchemaChange, - CassandraStatement::DropRole(_) => QueryType::SchemaChange, - CassandraStatement::DropTable(_) => QueryType::SchemaChange, - CassandraStatement::DropTrigger(_) => QueryType::SchemaChange, - CassandraStatement::DropType(_) => QueryType::SchemaChange, - CassandraStatement::DropUser(_) => QueryType::SchemaChange, - CassandraStatement::Grant(_) => QueryType::SchemaChange, - CassandraStatement::Insert(_) => QueryType::Write, - CassandraStatement::ListPermissions(_) => QueryType::Read, - CassandraStatement::ListRoles(_) => QueryType::Read, - CassandraStatement::Revoke(_) => QueryType::SchemaChange, - CassandraStatement::Select(_) => QueryType::Read, - CassandraStatement::Truncate(_) => QueryType::Write, - CassandraStatement::Update(_) => QueryType::Write, - CassandraStatement::Use(_) => QueryType::SchemaChange, - CassandraStatement::Unknown(_) => QueryType::Read, + CassandraOperation::Query { query: cql, .. } => { + // set to lowest type + let mut result = QueryType::SchemaChange; + for cql_statement in &cql.statements { + result = match cql_statement.get_query_type() { + QueryType::ReadWrite => { QueryType::ReadWrite }, + QueryType::Write => { match result { + QueryType::ReadWrite | + QueryType::Write => { result }, + QueryType::Read => { QueryType::ReadWrite }, + QueryType::SchemaChange | + PubSubMessage => { QueryType::Write }, + }}, + QueryType::Read => { if result==QueryType::SchemaChange {QueryType::Read } else { result }}, + QueryType::SchemaChange | + PubSubMessage => { result } + } + } + result }, - _ => QueryType::Read, + _ => QueryType::Read , } + } /// returns a list of table names from the CassandraOperation @@ -300,15 +286,19 @@ impl CassandraFrame { let mut result = vec![]; match &self.operation { CassandraOperation::Query { query: cql, .. } => { - if let Some(name) = CQL::get_table_name(&cql.statement) { - result.push(name.into()); + for cql_statement in &cql.statements { + if let Some(name) = CQLStatement::get_table_name(&cql_statement.statement) { + result.push(name.into()); + } } } CassandraOperation::Batch(batch) => { for q in &batch.queries { if let BatchStatementType::Statement(cql) = &q.ty { - if let Some(name) = CQL::get_table_name(&cql.statement) { - result.push(name.into()); + for cql_statement in &cql.statements { + if let Some(name) = CQLStatement::get_table_name(&cql_statement.statement) { + result.push(name.into()); + } } } } @@ -368,8 +358,10 @@ impl CassandraOperation { } */ if let CassandraOperation::Query { query: cql, .. } = self { - result.push(&mut *cql.statement) - }; + for cql_statement in &mut cql.statements { + result.push(&mut cql_statement.statement) + } + } result } @@ -498,142 +490,95 @@ impl CassandraOperation { } #[derive(PartialEq, Debug, Clone)] -pub struct CQL { - statement: Box, - pub has_error: bool, +pub struct CQLStatement { + pub(crate) statement: CassandraStatement, + has_error: bool, } -impl CQL { - pub fn get_statement_mut(&mut self) -> &mut CassandraStatement { - self.statement.as_mut() - } - - pub fn get_statement(&self) -> &CassandraStatement { - self.statement.as_ref() - } - - pub fn clone_statement(&self) -> CassandraStatement { - self.get_statement().clone() - } +impl CQLStatement { - fn from_value_and_col_spec(value: &Value, col_spec: &ColSpec) -> Operand { - match value { - Value::Some(vec) => { - let cbytes = CBytes::new(vec.clone()); - let message_value = - MessageValue::build_value_from_cstar_col_type(col_spec, &cbytes); - let pmsg_value = &message_value; - pmsg_value.into() - } - Value::Null => Operand::Null, - Value::NotSet => Operand::Null, - } - } - fn set_param_value_by_name( - name: &str, - query_params: &QueryParams, - param_types: &[ColSpec], - ) -> Operand { - if let Some(QueryValues::NamedValues(value_map)) = &query_params.values { - if let Some(value) = value_map.get(name) { - if let Some(idx) = value_map - .iter() - .enumerate() - .filter_map( - |(idx, (key, _value))| { - if key.eq(name) { - Some(idx) - } else { - None - } - }, - ) - .next() - { - return CQL::from_value_and_col_spec(value, ¶m_types[idx]); - } - } + pub fn is_begin_batch(&self) -> bool { + match &self.statement { + CassandraStatement::Delete(delete) => delete.begin_batch.is_some(), + CassandraStatement::Insert(insert) => insert.begin_batch.is_some(), + CassandraStatement::Update(update) => update.begin_batch.is_some(), + _ => false, } - Operand::Param(format!(":{}", name)) } - fn set_param_value_by_position( - param_idx: &mut usize, - query_params: &QueryParams, - param_types: &[ColSpec], - ) -> Operand { - if let Some(QueryValues::SimpleValues(values)) = &query_params.values { - if let Some(value) = values.get(*param_idx) { - *param_idx += 1; - CQL::from_value_and_col_spec(value, ¶m_types[*param_idx]) - } else { - *param_idx += 1; - Operand::Param("?".into()) - } - } else { - *param_idx += 1; - Operand::Param("?".into()) + pub fn is_apply_batch(&self) -> bool { + match &self.statement { + CassandraStatement::ApplyBatch => true, + _ => false, } } - fn set_operand_if_param( - operand: &Operand, - param_idx: &mut usize, - query_params: &QueryParams, - param_types: &[ColSpec], - ) -> Operand { - match operand { - Operand::Tuple(vec) => { - let mut vec2 = Vec::with_capacity(vec.len()); - vec.iter().for_each(|o| { - vec2.push(CQL::set_operand_if_param( - o, - param_idx, - query_params, - param_types, - )) - }); - - Operand::Tuple(vec2) - } - Operand::Param(param_name) => { - if param_name.starts_with('?') { - CQL::set_param_value_by_position(param_idx, query_params, param_types) - } else { - let name = param_name.split_at(0).1; - CQL::set_param_value_by_name(name, query_params, param_types) - } - } - Operand::Collection(vec) => { - let mut vec2 = Vec::with_capacity(vec.len()); - vec.iter().for_each(|o| { - vec2.push(CQL::set_operand_if_param( - o, - param_idx, - query_params, - param_types, - )) - }); - - Operand::Collection(vec2) - } - _ => operand.clone(), + /// returns the query type for the current statement. + pub fn get_query_type(&self) -> QueryType { + /* + Read, + Write, + ReadWrite, + SchemaChange, + PubSubMessage, + */ + match &self.statement { + CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, + CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, + CassandraStatement::AlterRole(_) => QueryType::SchemaChange, + CassandraStatement::AlterTable(_) => QueryType::SchemaChange, + CassandraStatement::AlterType(_) => QueryType::SchemaChange, + CassandraStatement::AlterUser(_) => QueryType::SchemaChange, + CassandraStatement::ApplyBatch => QueryType::ReadWrite, + CassandraStatement::CreateAggregate(_) => QueryType::SchemaChange, + CassandraStatement::CreateFunction(_) => QueryType::SchemaChange, + CassandraStatement::CreateIndex(_) => QueryType::SchemaChange, + CassandraStatement::CreateKeyspace(_) => QueryType::SchemaChange, + CassandraStatement::CreateMaterializedView(_) => QueryType::SchemaChange, + CassandraStatement::CreateRole(_) => QueryType::SchemaChange, + CassandraStatement::CreateTable(_) => QueryType::SchemaChange, + CassandraStatement::CreateTrigger(_) => QueryType::SchemaChange, + CassandraStatement::CreateType(_) => QueryType::SchemaChange, + CassandraStatement::CreateUser(_) => QueryType::SchemaChange, + CassandraStatement::Delete(_) => QueryType::Write, + CassandraStatement::DropAggregate(_) => QueryType::SchemaChange, + CassandraStatement::DropFunction(_) => QueryType::SchemaChange, + CassandraStatement::DropIndex(_) => QueryType::SchemaChange, + CassandraStatement::DropKeyspace(_) => QueryType::SchemaChange, + CassandraStatement::DropMaterializedView(_) => QueryType::SchemaChange, + CassandraStatement::DropRole(_) => QueryType::SchemaChange, + CassandraStatement::DropTable(_) => QueryType::SchemaChange, + CassandraStatement::DropTrigger(_) => QueryType::SchemaChange, + CassandraStatement::DropType(_) => QueryType::SchemaChange, + CassandraStatement::DropUser(_) => QueryType::SchemaChange, + CassandraStatement::Grant(_) => QueryType::SchemaChange, + CassandraStatement::Insert(_) => QueryType::Write, + CassandraStatement::ListPermissions(_) => QueryType::Read, + CassandraStatement::ListRoles(_) => QueryType::Read, + CassandraStatement::Revoke(_) => QueryType::SchemaChange, + CassandraStatement::Select(_) => QueryType::Read, + CassandraStatement::Truncate(_) => QueryType::Write, + CassandraStatement::Update(_) => QueryType::Write, + CassandraStatement::Use(_) => QueryType::SchemaChange, + CassandraStatement::Unknown(_) => QueryType::Read, } } - fn set_relation_elements_values( - param_idx: &mut usize, - query_params: &QueryParams, - param_types: &[ColSpec], - where_clause: &mut [RelationElement], - ) { - for relation_element in where_clause { - relation_element.value = CQL::set_operand_if_param( - &relation_element.value, - param_idx, - query_params, - param_types, - ); + /// returns the table name specified in the command if one is present. + pub fn get_table_name(statement: &CassandraStatement) -> Option<&String> { + match statement { + CassandraStatement::AlterTable(t) => Some(&t.name), + CassandraStatement::CreateIndex(i) => Some(&i.table), + CassandraStatement::CreateMaterializedView(m) => Some(&m.table), + CassandraStatement::CreateTable(t) => Some(&t.name), + CassandraStatement::Delete(d) => Some(&d.table_name), + CassandraStatement::DropTable(t) => Some(&t.name), + CassandraStatement::DropTrigger(t) => Some(&t.table), + CassandraStatement::Insert(i) => Some(&i.table_name), + CassandraStatement::Select(s) => Some(&s.table_name), + CassandraStatement::Truncate(t) => Some(t), + CassandraStatement::Update(u) => Some(&u.table_name), + _ => None, } } @@ -646,7 +591,7 @@ impl CQL { param_types: &[ColSpec], ) -> CassandraStatement { let mut param_idx: usize = 0; - let mut statement = self.clone_statement(); + let mut statement = self.statement.clone(); match &mut statement { CassandraStatement::Delete(delete) => { CQL::set_relation_elements_values( @@ -734,7 +679,7 @@ impl CQL { match operand { Operand::Tuple(vec) | Operand::Collection(vec) => { for oper in vec { - if CQL::has_params_in_operand(oper) { + if CQLStatement::has_params_in_operand(oper) { return true; } } @@ -747,7 +692,7 @@ impl CQL { fn has_params_in_relation_elements(where_clause: &[RelationElement]) -> bool { for relation_idx in where_clause { - if CQL::has_params_in_operand(&relation_idx.value) { + if CQLStatement::has_params_in_operand(&relation_idx.value) { return true; } } @@ -758,10 +703,10 @@ impl CQL { pub fn has_params(statement: &CassandraStatement) -> bool { match statement { CassandraStatement::Delete(delete) => { - if CQL::has_params_in_relation_elements(&delete.where_clause) { + if CQLStatement::has_params_in_relation_elements(&delete.where_clause) { return true; } - if CQL::has_params_in_relation_elements(&delete.if_clause) { + if CQLStatement::has_params_in_relation_elements(&delete.if_clause) { return true; } } @@ -775,7 +720,7 @@ impl CQL { } } CassandraStatement::Select(select) => { - return CQL::has_params_in_relation_elements(&select.where_clause); + return CQLStatement::has_params_in_relation_elements(&select.where_clause); } CassandraStatement::Update(update) => { for assignment_element in &update.assignments { @@ -797,10 +742,10 @@ impl CQL { } } } - if CQL::has_params_in_relation_elements(&update.where_clause) { + if CQLStatement::has_params_in_relation_elements(&update.where_clause) { return true; } - if CQL::has_params_in_relation_elements(&update.if_clause) { + if CQLStatement::has_params_in_relation_elements(&update.if_clause) { return true; } } @@ -809,35 +754,179 @@ impl CQL { false } +} + +impl Display for CQLStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.statement.fmt( f ) + } +} + +#[dervie(Copy,Clone)] +struct CQLStatementListWrapper<'a> { + statement_ref : &'a Vec> +} + +#[derive(PartialEq, Debug, Clone)] +pub struct CQL { + pub statements: Vec>, + has_error: bool, +} + + + +impl CQL { + /// the number of statements in the CQL + pub fn get_statement_count(&self) -> usize { + self.statements.len() + } + + fn from_value_and_col_spec(value: &Value, col_spec: &ColSpec) -> Operand { + match value { + Value::Some(vec) => { + let cbytes = CBytes::new(vec.clone()); + let message_value = + MessageValue::build_value_from_cstar_col_type(col_spec, &cbytes); + let pmsg_value = &message_value; + pmsg_value.into() + } + Value::Null => Operand::Null, + Value::NotSet => Operand::Null, + } + } + fn set_param_value_by_name( + name: &str, + query_params: &QueryParams, + param_types: &[ColSpec], + ) -> Operand { + if let Some(QueryValues::NamedValues(value_map)) = &query_params.values { + if let Some(value) = value_map.get(name) { + if let Some(idx) = value_map + .iter() + .enumerate() + .filter_map( + |(idx, (key, _value))| { + if key.eq(name) { + Some(idx) + } else { + None + } + }, + ) + .next() + { + return CQL::from_value_and_col_spec(value, ¶m_types[idx]); + } + } + } + Operand::Param(format!(":{}", name)) + } + + fn set_param_value_by_position( + param_idx: &mut usize, + query_params: &QueryParams, + param_types: &[ColSpec], + ) -> Operand { + if let Some(QueryValues::SimpleValues(values)) = &query_params.values { + if let Some(value) = values.get(*param_idx) { + *param_idx += 1; + CQL::from_value_and_col_spec(value, ¶m_types[*param_idx]) + } else { + *param_idx += 1; + Operand::Param("?".into()) + } + } else { + *param_idx += 1; + Operand::Param("?".into()) + } + } + + fn set_operand_if_param( + operand: &Operand, + param_idx: &mut usize, + query_params: &QueryParams, + param_types: &[ColSpec], + ) -> Operand { + match operand { + Operand::Tuple(vec) => { + let mut vec2 = Vec::with_capacity(vec.len()); + vec.iter().for_each(|o| { + vec2.push(CQL::set_operand_if_param( + o, + param_idx, + query_params, + param_types, + )) + }); + + Operand::Tuple(vec2) + } + Operand::Param(param_name) => { + if param_name.starts_with('?') { + CQL::set_param_value_by_position(param_idx, query_params, param_types) + } else { + let name = param_name.split_at(0).1; + CQL::set_param_value_by_name(name, query_params, param_types) + } + } + Operand::Collection(vec) => { + let mut vec2 = Vec::with_capacity(vec.len()); + vec.iter().for_each(|o| { + vec2.push(CQL::set_operand_if_param( + o, + param_idx, + query_params, + param_types, + )) + }); + + Operand::Collection(vec2) + } + _ => operand.clone(), + } + } + + fn set_relation_elements_values( + param_idx: &mut usize, + query_params: &QueryParams, + param_types: &[ColSpec], + where_clause: &mut [RelationElement], + ) { + for relation_element in where_clause { + relation_element.value = CQL::set_operand_if_param( + &relation_element.value, + param_idx, + query_params, + param_types, + ); + } + } + + pub fn to_query_string(&self) -> String { - self.statement.to_string() + self.statements + .iter() + .map(|c| c.statement.to_string()) + .join("; ") } /// the CassandraAST handles multiple queries in a string separated by semi-colons: `;` however /// CQL only stores one query so this method only returns the first one if there are multiples. pub fn parse_from_string(cql_query_str: &str) -> Self { + info!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); + + let mut vec = Vec::with_capacity(ast.statements.len()); + + for statement in &ast.statements { + vec.push(Box::new(CQLStatement { + has_error: statement.0, + statement: statement.1.clone(), + })); + } CQL { has_error: ast.has_error(), - statement: Box::new(ast.statements.first().unwrap().clone()), - } - } - - /// returns the table name specified in the command if one is present. - pub fn get_table_name(statement: &CassandraStatement) -> Option<&String> { - match statement { - CassandraStatement::AlterTable(t) => Some(&t.name), - CassandraStatement::CreateIndex(i) => Some(&i.table), - CassandraStatement::CreateMaterializedView(m) => Some(&m.table), - CassandraStatement::CreateTable(t) => Some(&t.name), - CassandraStatement::Delete(d) => Some(&d.table_name), - CassandraStatement::DropTable(t) => Some(&t.name), - CassandraStatement::DropTrigger(t) => Some(&t.table), - CassandraStatement::Insert(i) => Some(&i.table_name), - CassandraStatement::Select(s) => Some(&s.table_name), - CassandraStatement::Truncate(t) => Some(t), - CassandraStatement::Update(u) => Some(&u.table_name), - _ => None, + statements: vec, } } } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 00500be5d..6a960f7fe 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -1,4 +1,4 @@ -use crate::frame::{CassandraOperation, CassandraResult, Frame, CQL}; +use crate::frame::{CassandraOperation, CassandraResult, Frame}; use crate::message::{IntSize, Message, MessageValue}; use crate::{ error::ChainResponse, @@ -10,6 +10,7 @@ use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::select::SelectElement; use serde::Deserialize; use std::collections::HashMap; +use crate::frame::cassandra::CQLStatement; #[derive(Deserialize, Debug, Clone)] pub struct CassandraPeersRewriteConfig { @@ -72,22 +73,24 @@ fn extract_native_port_column(message: &mut Message) -> Vec { let mut result: Vec = vec![]; if let Some(Frame::Cassandra(cassandra)) = message.frame() { if let CassandraOperation::Query { query, .. } = &cassandra.operation { - let statement = query.get_statement(); - if let CassandraStatement::Select(select) = &statement { - if let Some(table_name) = CQL::get_table_name(statement) { - if table_name.eq("system.peers_v2") { - select - .columns - .iter() - .for_each(|select_element| match select_element { - SelectElement::Column(col_name) => { - if col_name.name.eq("native_port") { - result.push(col_name.alias_or_name()); + for cql_statement in &query.statements { + let statement = &cql_statement.statement; + if let CassandraStatement::Select(select) = &statement { + if let Some(table_name) = CQLStatement::get_table_name(&statement) { + if table_name.eq("system.peers_v2") { + select + .columns + .iter() + .for_each(|select_element| match select_element { + SelectElement::Column(col_name) => { + if col_name.name.eq("native_port") { + result.push(col_name.alias_or_name()); + } } - } - SelectElement::Star => result.push("native_port".to_string()), - _ => {} - }); + SelectElement::Star => result.push("native_port".to_string()), + _ => {} + }); + } } } } diff --git a/shotover-proxy/src/transforms/mod.rs b/shotover-proxy/src/transforms/mod.rs index 300780a0f..ebf68411c 100644 --- a/shotover-proxy/src/transforms/mod.rs +++ b/shotover-proxy/src/transforms/mod.rs @@ -304,6 +304,7 @@ pub async fn build_chain_from_config( } use std::slice::IterMut; +use tracing::info; /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the /// remaining transforms that will process the messages attached to this [`Wrapper`]. @@ -362,6 +363,7 @@ impl<'a> Wrapper<'a> { let transform_name = transform.get_name(); let chain_name = self.chain_name.clone(); + info!( "call_next_transform calling {} {}", transform_name, chain_name ); let start = Instant::now(); let result = CONTEXT_CHAIN_NAME diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 5509cbede..90a9d2111 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,5 +1,5 @@ use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, CQL}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::MessageValue; use crate::transforms::protect::key_management::{KeyManager, KeyManagerConfig}; use crate::transforms::{Transform, Transforms, Wrapper}; @@ -17,6 +17,7 @@ use sodiumoxide::crypto::secretbox::{Key, Nonce}; use sodiumoxide::hex; use std::collections::HashMap; use tracing::warn; +use crate::frame::cassandra::CQLStatement; mod aws_kms; mod key_management; @@ -238,14 +239,17 @@ impl Transform for Protect { .. })) = message.frame() { - let statement = query.get_statement_mut(); - if let Some(table_name) = CQL::get_table_name(statement) { - if let Some((_, tables)) = self.keyspace_table_columns.get_key_value(table_name) - { - if let Some((_, columns)) = tables.get_key_value(table_name) { - data_changed = - encrypt_columns(statement, columns, &self.key_source, &self.key_id) - .await?; + for cql_statement in &mut query.statements { + let statement = &mut cql_statement.statement; + + if let Some(table_name) = CQLStatement::get_table_name(statement) { + if let Some((_, tables)) = self.keyspace_table_columns.get_key_value(table_name) + { + if let Some((_, columns)) = tables.get_key_value(table_name) { + data_changed = + encrypt_columns(statement, columns, &self.key_source, &self.key_id) + .await?; + } } } } @@ -275,45 +279,47 @@ impl Transform for Protect { .. })) = request.frame() { - let statement = query.get_statement(); - if let Some(table_name) = CQL::get_table_name(statement) { - if let Some((_keyspace, tables)) = + for cql_statement in &mut query.statements { + let statement = &mut cql_statement.statement; + if let Some(table_name) = CQLStatement::get_table_name(statement) { + if let Some((_keyspace, tables)) = self.keyspace_table_columns.get_key_value(table_name) - { - if let Some((_table, protect_columns)) = - tables.get_key_value(table_name) { - if let CassandraStatement::Select(select) = &statement { - let positions: Vec = select - .columns - .iter() - .enumerate() - .filter_map(|(i, col)| { - if let SelectElement::Column(named) = col { - if protect_columns.contains(&named.name) { - Some(i) + if let Some((_table, protect_columns)) = + tables.get_key_value(table_name) + { + if let CassandraStatement::Select(select) = &statement { + let positions: Vec = select + .columns + .iter() + .enumerate() + .filter_map(|(i, col)| { + if let SelectElement::Column(named) = col { + if protect_columns.contains(&named.name) { + Some(i) + } else { + None + } } else { None } - } else { - None - } - }) - .collect(); - for row in &mut *rows { - for index in &positions { - if let Some(v) = row.get_mut(*index) { - if let MessageValue::Bytes(_) = v { - let protected = - Protected::from_encrypted_bytes_value(v) + }) + .collect(); + for row in &mut *rows { + for index in &positions { + if let Some(v) = row.get_mut(*index) { + if let MessageValue::Bytes(_) = v { + let protected = + Protected::from_encrypted_bytes_value(v) + .await?; + let new_value: MessageValue = protected + .unprotect(&self.key_source, &self.key_id) .await?; - let new_value: MessageValue = protected - .unprotect(&self.key_source, &self.key_id) - .await?; - *v = new_value; - invalidate_cache = true; - } else { - warn!("Tried decrypting non-blob column") + *v = new_value; + invalidate_cache = true; + } else { + warn!("Tried decrypting non-blob column") + } } } } diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index bcf54e4cb..9440ecc99 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,6 +1,6 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame, CQL}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ @@ -12,11 +12,11 @@ use bytes::{BufMut, Bytes, BytesMut}; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, RelationElement, RelationOperator}; -use cql3_parser::select::SelectElement; use itertools::Itertools; use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; use tracing_log::log::info; +use crate::frame::cassandra::CQLStatement; enum CacheableState { Read, @@ -83,6 +83,7 @@ impl SimpleRedisCache { // + if the request is a CassandraOperation::Query then we consume a single message from messages_redis_response converting it to a cassandra response // * These are the cassandra responses that we return from the function. + info!("get_or_update_from_cache called"); let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); for cass_request in &mut messages_cass_request { match cass_request.frame() { @@ -96,7 +97,8 @@ impl SimpleRedisCache { CacheableState::Read | CacheableState::Update | CacheableState::Delete => { - if let Some(table_name) = CQL::get_table_name(statement) { + info!("get_or_update_from_cache processing cacheable state"); + if let Some(table_name) = CQLStatement::get_table_name(statement) { if let Some(table_cache_schema) = self.caching_schema.get(table_name) { @@ -121,6 +123,7 @@ impl SimpleRedisCache { } } _ => { + info!("get_or_update_from_cache not processing cacheable state"); // do nothing here but check again again outside of match as state may have changed } } @@ -142,6 +145,8 @@ impl SimpleRedisCache { } } + info!("get_or_update_from_cache calling cache_chain.process_request"); + match self .cache_chain .process_request( @@ -151,6 +156,7 @@ impl SimpleRedisCache { .await { Ok(messages_redis_response) => { + info!("get_or_update_from_cache received OK from cache_chain.process_request"); // Replace cass_request messages with cassandra responses in place. // We reuse the vec like this to save allocations. let mut messages_redis_response_iter = messages_redis_response.into_iter(); @@ -250,7 +256,7 @@ fn build_zrangebylex_min_max_from_sql( } fn is_cacheable(statement: &CassandraStatement) -> CacheableState { - let has_params = CQL::has_params(statement); + let has_params = CQLStatement::has_params(statement); match statement { CassandraStatement::Select(select) => { @@ -260,14 +266,16 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) } else if select.where_clause.is_empty() { CacheableState::Skip("Can not cache if where clause is empty".into()) - } else if !select.columns.is_empty() { + /* } else if !select.columns.is_empty() { if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { CacheableState::Read } else { CacheableState::Skip( - "Can not cache if columns other than '*' are not selected".into(), + "Can not cache if columns other than '*' are selected".into(), ) } + + */ } else { CacheableState::Read } @@ -495,38 +503,51 @@ impl Transform for SimpleRedisCache { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { let mut read_cache = true; for m in &mut message_wrapper.messages { - if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { .. }, - .. - })) = m.frame() - { - /* let statement = query.get_statement(); - match is_cacheable(statement ) { - CacheableState::Read | - CacheableState::Update | - CacheableState::Delete => {} - CacheableState::Skip(reason) => { - tracing::info!( "Cache skipped for {} due to {}", statement, reason ); - use_cache = false; + if let Some(&mut Frame::Cassandra(CassandraFrame{ operation : CassandraOperation::Query{ query , ..},..})) = &mut m.frame() { + + + + + //if let Some(&mut Frame::Cassandra(CassandraFrame { + // operation: &CassandraOperation::Query { query, .. }, + // .. + // })) = &mut m.frame() + // { + /* let statement = query.get_statement(); + match is_cacheable(statement ) { + CacheableState::Read | + CacheableState::Update | + CacheableState::Delete => {} + CacheableState::Skip(reason) => { + tracing::info!( "Cache skipped for {} due to {}", statement, reason ); + use_cache = false; + } + CacheableState::Err(reason) => { + tracing::error!("Cache failed for {} due to {}", statement, reason); + use_cache = false; + } } - CacheableState::Err(reason) => { - tracing::error!("Cache failed for {} due to {}", statement, reason); - use_cache = false; + + */ + for cql_statement in query.statements { + info!("cache transform processing {}", cql_statement); + match cql_statement.get_query_type() { + QueryType::Read => {} + QueryType::Write => read_cache = false, + QueryType::ReadWrite => read_cache = false, + QueryType::SchemaChange => read_cache = false, + QueryType::PubSubMessage => {} + } } - } - */ - match m.get_query_type() { - QueryType::Read => {} - QueryType::Write => read_cache = false, - QueryType::ReadWrite => read_cache = false, - QueryType::SchemaChange => read_cache = false, - QueryType::PubSubMessage => {} - } + + } else { + read_cache = false; } } + info!("cache transform read_cache:{} ", read_cache); - // If there are no write queries (all queries are reads) we can use the cache + // If there are no write queries (all queries are reads) we can read the cache if read_cache { match self .get_or_update_from_cache(message_wrapper.messages.clone()) diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index 7192b4812..e844afaab 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1124,7 +1124,7 @@ mod cache { // query against some other field assert_query_result( cassandra_session, - "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table WHERE x=11", + "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table WHERE x=11 ALLOW FILTERING", &[], ); @@ -1191,7 +1191,7 @@ mod cache { // query against some other field assert_query_result( cassandra_session, - "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE x=11", + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE x=11 ALLOW FILTERING", &[], ); diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 700656a89..df4a0fc91 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,5 +1,6 @@ use cassandra_cpp::{stmt, Cluster, Error, Session, Value, ValueType}; use ordered_float::OrderedFloat; +use tracing::info; use test_helpers::try_wait_for_socket_to_open; pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { @@ -108,6 +109,7 @@ impl ResultValue { #[allow(unused)] pub fn execute_query(session: &Session, query: &str) -> Vec> { let statement = stmt!(query); + info!( "executing query: {}", query); match session.execute(&statement).wait() { Ok(result) => result .into_iter() From 11c91f9329b6bea3e99a974381f1c2e9a4fb1d27 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 12 Apr 2022 07:48:46 +0100 Subject: [PATCH 19/60] fixed some coding issues --- shotover-proxy/src/frame/cassandra.rs | 5 ----- shotover-proxy/src/transforms/redis/cache.rs | 8 ++++---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index b18152696..d8afaadc4 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -762,11 +762,6 @@ impl Display for CQLStatement { } } -#[dervie(Copy,Clone)] -struct CQLStatementListWrapper<'a> { - statement_ref : &'a Vec> -} - #[derive(PartialEq, Debug, Clone)] pub struct CQL { pub statements: Vec>, diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 9440ecc99..0628937e5 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -500,10 +500,10 @@ fn add_query_values( #[async_trait] impl Transform for SimpleRedisCache { - async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { + async fn transform<'a>(&'a mut self, message_wrapper: Wrapper<'a>) -> ChainResponse { let mut read_cache = true; - for m in &mut message_wrapper.messages { - if let Some(&mut Frame::Cassandra(CassandraFrame{ operation : CassandraOperation::Query{ query , ..},..})) = &mut m.frame() { + for m in &message_wrapper.messages { + if let Some(Frame::Cassandra(CassandraFrame{ operation : CassandraOperation::Query{ query , ..},..})) = m.frame() { @@ -529,7 +529,7 @@ impl Transform for SimpleRedisCache { } */ - for cql_statement in query.statements { + for cql_statement in &query.statements { info!("cache transform processing {}", cql_statement); match cql_statement.get_query_type() { QueryType::Read => {} From d8d1c41afbea897b92927ab47edc3362822cb1b7 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 14 Apr 2022 12:35:09 +0100 Subject: [PATCH 20/60] compilation works again --- Cargo.lock | 6 +- shotover-proxy/src/frame/cassandra.rs | 205 ++++- .../src/transforms/cassandra/peers_rewrite.rs | 2 +- shotover-proxy/src/transforms/mod.rs | 5 +- shotover-proxy/src/transforms/protect/mod.rs | 30 +- shotover-proxy/src/transforms/redis/cache.rs | 787 +++++++++--------- shotover-proxy/tests/helpers/cassandra.rs | 4 +- 7 files changed, 591 insertions(+), 448 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a5d8c4eb8..99da3b555 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#c2d566bce5f070b93cbbf79a3d6eaa6e58ad15af" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#34cebfc9972c8aab0ce2d31c894c9d338ccdbb06" dependencies = [ "bigdecimal", "bytes", @@ -1054,9 +1054,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9100414882e15fb7feccb4897e5f0ff0ff1ca7d1a86a23208ada4d7a18e6c6c4" +checksum = "6330e8a36bd8c859f3fa6d9382911fbb7147ec39807f63b923933a247240b9ba" [[package]] name = "httpdate" diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index d8afaadc4..3dcd4c4e1 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -8,13 +8,8 @@ use cassandra_protocol::frame::frame_error::ErrorBody; use cassandra_protocol::frame::frame_query::BodyReqQuery; use cassandra_protocol::frame::frame_request::RequestBody; use cassandra_protocol::frame::frame_response::ResponseBody; -use cassandra_protocol::frame::frame_result::{ - BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ColSpec, ResResultBody, - RowsMetadata, RowsMetadataFlags, -}; -use cassandra_protocol::frame::{ - Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version, -}; +use cassandra_protocol::frame::frame_result::{BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ColSpec, ResResultBody, RowsMetadata, RowsMetadataFlags}; +use cassandra_protocol::frame::{Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version}; use cassandra_protocol::query::{QueryParams, QueryValues}; use cassandra_protocol::types::blob::Blob; use cassandra_protocol::types::cassandra_type::CassandraType; @@ -30,14 +25,15 @@ use nonzero_ext::nonzero; use sodiumoxide::hex; use std::convert::TryInto; use std::fmt::{Display, Formatter}; +use std::io::Cursor; use std::net::IpAddr; use std::num::NonZeroU32; use std::str::FromStr; use tracing::info; use uuid::Uuid; -use crate::message::{MessageValue, QueryType}; use crate::message::QueryType::PubSubMessage; +use crate::message::{MessageValue, QueryType}; /// Extract the length of a BATCH statement (count of requests) from the body bytes fn get_batch_len(bytes: &[u8]) -> Result { @@ -256,29 +252,31 @@ impl CassandraFrame { PubSubMessage, */ match &self.operation { - CassandraOperation::Query { query: cql, .. } => { + CassandraOperation::Query { query: cql, .. } => { // set to lowest type let mut result = QueryType::SchemaChange; for cql_statement in &cql.statements { result = match cql_statement.get_query_type() { - QueryType::ReadWrite => { QueryType::ReadWrite }, - QueryType::Write => { match result { - QueryType::ReadWrite | - QueryType::Write => { result }, - QueryType::Read => { QueryType::ReadWrite }, - QueryType::SchemaChange | - PubSubMessage => { QueryType::Write }, - }}, - QueryType::Read => { if result==QueryType::SchemaChange {QueryType::Read } else { result }}, - QueryType::SchemaChange | - PubSubMessage => { result } + QueryType::ReadWrite => QueryType::ReadWrite, + QueryType::Write => match result { + QueryType::ReadWrite | QueryType::Write => result, + QueryType::Read => QueryType::ReadWrite, + QueryType::SchemaChange | PubSubMessage => QueryType::Write, + }, + QueryType::Read => { + if result == QueryType::SchemaChange { + QueryType::Read + } else { + result + } + } + QueryType::SchemaChange | PubSubMessage => result, } } result - }, - _ => QueryType::Read , + } + _ => QueryType::Read, } - } /// returns a list of table names from the CassandraOperation @@ -296,7 +294,9 @@ impl CassandraFrame { for q in &batch.queries { if let BatchStatementType::Statement(cql) = &q.ty { for cql_statement in &cql.statements { - if let Some(name) = CQLStatement::get_table_name(&cql_statement.statement) { + if let Some(name) = + CQLStatement::get_table_name(&cql_statement.statement) + { result.push(name.into()); } } @@ -365,6 +365,27 @@ impl CassandraOperation { result } + /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch + /// An Err is returned if the operation cannot contain queries or the queries failed to parse. + /// + /// TODO: This will return a custom iterator type when BATCH support is added + pub fn get_cql_statements(&mut self) -> Vec<&mut Box> { + let mut result = vec![]; + /* + match self { + CassandraOperation::Query { query: cql, .. } => result.push( &mut *cql.statement), + // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol + _ => { } + } + */ + if let CassandraOperation::Query { query: cql, .. } = self { + for cql_statement in &mut cql.statements { + result.push(cql_statement) + } + } + result + } + fn to_direction(&self) -> Direction { match self { CassandraOperation::Query { .. } => Direction::Request, @@ -491,12 +512,11 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQLStatement { - pub(crate) statement: CassandraStatement, - has_error: bool, + pub statement: CassandraStatement, + pub has_error: bool, } impl CQLStatement { - pub fn is_begin_batch(&self) -> bool { match &self.statement { CassandraStatement::Delete(delete) => delete.begin_batch.is_some(), @@ -753,23 +773,20 @@ impl CQLStatement { } false } - } impl Display for CQLStatement { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - self.statement.fmt( f ) + self.statement.fmt(f) } } #[derive(PartialEq, Debug, Clone)] pub struct CQL { pub statements: Vec>, - has_error: bool, + pub(crate) has_error: bool, } - - impl CQL { /// the number of statements in the CQL pub fn get_statement_count(&self) -> usize { @@ -897,7 +914,6 @@ impl CQL { } } - pub fn to_query_string(&self) -> String { self.statements .iter() @@ -1014,6 +1030,85 @@ pub enum CassandraResult { Void, } +impl Serialize for CassandraResult { + fn serialize(&self, cursor: &mut Cursor<&mut Vec>) { + let res_result_body : ResResultBody = match self { + CassandraResult::Rows { value,metadata } => { + match value { + MessageValue::Rows(rows) => { + let mut rows_content: Vec> = Vec::with_capacity(rows.len()); + for row in rows { + let mut row_data = Vec::with_capacity(row.len()); + for element in row { + let b = cassandra_protocol::types::value::Bytes::from(element.clone()); + row_data.push(CBytes::new(b.into_inner())); + } + rows_content.push(row_data); + } + let body_res_result_rows = BodyResResultRows { + metadata: metadata.clone(), + rows_count: rows.len() as CInt, + rows_content + }; + ResResultBody::Rows(body_res_result_rows) + } + _ => ResResultBody::Void + } + } + CassandraResult::SetKeyspace( keyspace ) => { + ResResultBody::SetKeyspace(*keyspace.clone()) + } + CassandraResult::Prepared( prepared ) => { + ResResultBody::Prepared( *prepared.clone() ) + } + CassandraResult::SchemaChange(schema_change) => { + ResResultBody::SchemaChange( schema_change.clone() ) + } + CassandraResult::Void => { + ResResultBody::Void + } + }; + res_result_body.serialize(cursor); + } +} + +impl CassandraResult { + pub fn from_cursor( + cursor: &mut Cursor<&[u8]>, + version: Version, + ) -> Result { + + let res_result_body = ResResultBody::from_cursor(cursor, version)?; + Ok(match res_result_body { + ResResultBody::Void => CassandraResult::Void, + + ResResultBody::Rows( body_res_result_rows) => { + let mut value : Vec> = Vec::with_capacity(body_res_result_rows.rows_content.len()); + for row in &body_res_result_rows.rows_content { + let mut row_values = Vec::with_capacity( body_res_result_rows.metadata.col_specs.len()); + for (cbytes,colspec) in row.iter().zip( body_res_result_rows.metadata.col_specs.iter() ) { + row_values.push( MessageValue::build_value_from_cstar_col_type(colspec, cbytes) ); + } + value.push(row_values); + } + CassandraResult::Rows { + value : MessageValue::Rows(value), + metadata: body_res_result_rows.metadata.clone(), + } + }, + ResResultBody::SetKeyspace(keyspace) => { + CassandraResult::SetKeyspace( Box::new( keyspace.clone() )) + } + ResResultBody::Prepared(prepared) => { + CassandraResult::Prepared(Box::new( prepared.clone())) + } + ResResultBody::SchemaChange(schema_change) => { + CassandraResult::SchemaChange( schema_change.clone()) + } + }) + } +} + #[derive(PartialEq, Debug, Clone)] pub enum BatchStatementType { Statement(CQL), @@ -1034,3 +1129,47 @@ pub struct CassandraBatch { serial_consistency: Option, timestamp: Option, } + +#[cfg(test)] +mod test { + use crate::frame::CQL; + + #[test] + fn cql_round_trip_test() { + let query = r#"BEGIN BATCH + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (1, 11, 'foo'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (2, 12, 'bar'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); + APPLY BATCH;"#; + + let expected = "BEGIN BATCH INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (1, 11, 'foo'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (2, 12, 'bar'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); APPLY BATCH"; + let cql = CQL::parse_from_string(query); + let result = cql.to_query_string(); + assert_eq!(expected, result) + } + + #[test] + fn cql_parse_multiple_test() { + let query = r#"INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (1, 11, 'foo'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (2, 12, 'bar'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz');"#; + + let cql = CQL::parse_from_string(query); + assert_eq!(3, cql.get_statement_count()); + assert!(!cql.has_error); + } + + #[test] + fn cql_bad_statement_test() { + let query = r#"INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (1, 11, 'foo'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz');"#; + + let cql = CQL::parse_from_string(query); + assert_eq!(3, cql.get_statement_count()); + assert!(cql.has_error); + assert!(!cql.statements[0].has_error); + assert!(cql.statements[1].has_error); + assert!(!cql.statements[2].has_error); + } +} diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 6a960f7fe..7993f7142 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -1,3 +1,4 @@ +use crate::frame::cassandra::CQLStatement; use crate::frame::{CassandraOperation, CassandraResult, Frame}; use crate::message::{IntSize, Message, MessageValue}; use crate::{ @@ -10,7 +11,6 @@ use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::select::SelectElement; use serde::Deserialize; use std::collections::HashMap; -use crate::frame::cassandra::CQLStatement; #[derive(Deserialize, Debug, Clone)] pub struct CassandraPeersRewriteConfig { diff --git a/shotover-proxy/src/transforms/mod.rs b/shotover-proxy/src/transforms/mod.rs index ebf68411c..a72d5756c 100644 --- a/shotover-proxy/src/transforms/mod.rs +++ b/shotover-proxy/src/transforms/mod.rs @@ -363,7 +363,10 @@ impl<'a> Wrapper<'a> { let transform_name = transform.get_name(); let chain_name = self.chain_name.clone(); - info!( "call_next_transform calling {} {}", transform_name, chain_name ); + info!( + "call_next_transform calling {} {}", + transform_name, chain_name + ); let start = Instant::now(); let result = CONTEXT_CHAIN_NAME diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 90a9d2111..d5d133396 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,4 +1,5 @@ use crate::error::ChainResponse; +use crate::frame::cassandra::CQLStatement; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::MessageValue; use crate::transforms::protect::key_management::{KeyManager, KeyManagerConfig}; @@ -17,7 +18,6 @@ use sodiumoxide::crypto::secretbox::{Key, Nonce}; use sodiumoxide::hex; use std::collections::HashMap; use tracing::warn; -use crate::frame::cassandra::CQLStatement; mod aws_kms; mod key_management; @@ -243,12 +243,17 @@ impl Transform for Protect { let statement = &mut cql_statement.statement; if let Some(table_name) = CQLStatement::get_table_name(statement) { - if let Some((_, tables)) = self.keyspace_table_columns.get_key_value(table_name) + if let Some((_, tables)) = + self.keyspace_table_columns.get_key_value(table_name) { if let Some((_, columns)) = tables.get_key_value(table_name) { - data_changed = - encrypt_columns(statement, columns, &self.key_source, &self.key_id) - .await?; + data_changed = encrypt_columns( + statement, + columns, + &self.key_source, + &self.key_id, + ) + .await?; } } } @@ -283,10 +288,10 @@ impl Transform for Protect { let statement = &mut cql_statement.statement; if let Some(table_name) = CQLStatement::get_table_name(statement) { if let Some((_keyspace, tables)) = - self.keyspace_table_columns.get_key_value(table_name) + self.keyspace_table_columns.get_key_value(table_name) { if let Some((_table, protect_columns)) = - tables.get_key_value(table_name) + tables.get_key_value(table_name) { if let CassandraStatement::Select(select) = &statement { let positions: Vec = select @@ -310,10 +315,15 @@ impl Transform for Protect { if let Some(v) = row.get_mut(*index) { if let MessageValue::Bytes(_) = v { let protected = - Protected::from_encrypted_bytes_value(v) - .await?; + Protected::from_encrypted_bytes_value( + v, + ) + .await?; let new_value: MessageValue = protected - .unprotect(&self.key_source, &self.key_id) + .unprotect( + &self.key_source, + &self.key_id, + ) .await?; *v = new_value; invalidate_cache = true; diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 0628937e5..0fcf2e1bf 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,6 +1,8 @@ + use crate::config::topology::TopicHolder; use crate::error::ChainResponse; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame}; +use crate::frame::cassandra::CQLStatement; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, CQL, Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ @@ -8,20 +10,37 @@ use crate::transforms::{ }; use anyhow::Result; use async_trait::async_trait; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::BytesMut; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, RelationElement, RelationOperator}; -use itertools::Itertools; -use serde::Deserialize; +use serde::{Deserialize}; use std::collections::{BTreeMap, HashMap}; -use tracing_log::log::info; -use crate::frame::cassandra::CQLStatement; +use std::io::Cursor; +use tracing_log::log::{info, warn}; +use cassandra_protocol::frame::Serialize; +use itertools::Itertools; + +/* +Uses redis as a cache. Data is stored in Redis as a Hash. +The key for the Redis cache is the key for the Cassandra query. +Only exact Redis matches are supported, though the redis key can equate to a Cassandra scan. +If the key is requested again the result query result is returned from the cache. +If the key is deleted or updated the key is removed from the cache. +If the table is dropped the keys are removed from the cache. +The redis hash keys are: +data - serialized form of the row data from cassandra. +metadata - serialized form of the metadata from cassandra. + + */ enum CacheableState { - Read, - Update, - Delete, + /// string is the table name + Read(String), + /// string is the table name + Update(String), + /// string is the table name + Delete(String), /// string is the reason for the skip Skip(String), /// string is the reason for the error @@ -56,14 +75,131 @@ pub struct SimpleRedisCache { caching_schema: HashMap, } + impl SimpleRedisCache { fn get_name(&self) -> &'static str { "SimpleRedisCache" } - async fn get_or_update_from_cache( + /// Build the messages for the cache query from the cassandra request messages. + /// returns the Redis Messages or a `CacheableState:Err` or `CacheableState::Skip` as the error + fn build_cache_query(&mut self, + cassandra_messages: &mut Messages, ) -> Result { + let mut messages_redis_request = Vec::with_capacity(cassandra_messages.len()); + for cass_request in cassandra_messages { + match &mut cass_request.frame() { + Some(Frame::Cassandra(frame)) => { + for cql_statement in frame.operation.get_cql_statements() { + let mut state = is_cacheable(cql_statement); + if let CacheableState::Read(table_name) = &mut state { + let statement = &cql_statement.statement; + info!("build_cache_query processing cacheable state"); + if let Some(table_cache_schema) = + self.caching_schema.get(table_name.as_str()) + { + match build_redis_key_from_cql3( + statement, + table_cache_schema, + ) { + Ok(redis_key) => { + let commands_buffer = vec![ + RedisFrame::BulkString("GET".into()), + RedisFrame::BulkString(redis_key.into()), ]; + + messages_redis_request.push(Message::from_frame( + Frame::Redis(RedisFrame::Array(commands_buffer)), + )); + }, + Err(err_state) => {state = err_state;} + } + } else { + state = CacheableState::Skip( + format!("Table {} not in caching list", table_name) + ); + } + } else { + state = CacheableState::Skip(format!("{} is not a readable query",cql_statement)); + } + + match state { + CacheableState::Err(_) | + CacheableState::Skip(_) => { return Err(state) } + _ => {}, + } + } + } + _ => { + return Err(CacheableState::Err(format!( + "cannot fetch {cass_request:?} from cache" + ))) + } + } + } + Ok(messages_redis_request) + } + + /// unwraps redis response messages into cassandra messages. It does this by replacing the Cassandra + /// request messages with their corresponding Cassandra response messages and returns them. + /// Result is either the modified message_cass_request (now response messages) or and CacheableState::Err. + fn unwrap_cache_response(&self, messages_redis_response: Messages, mut cassandra_messages: Messages) -> Result { + // Replace cass_request messages with cassandra responses in place. + // We reuse the vec like this to save allocations. + let mut messages_redis_response_iter = messages_redis_response.into_iter(); + /* there is a redis response for each statement in a CassandraMessage so we have to map + the redis responses back to the cassandra requests + */ + + for cass_request in cassandra_messages.iter_mut() { + // the responses for this request + let cassandra_result : Result = if let Some(Frame::Cassandra(frame)) = &mut cass_request.frame() { + let queries = frame.operation.queries(); + if queries.len() != 1 { + Err(CacheableState::Err("Cacheable Cassandra query must be only one statement".into())) + } else { + if let Some(mut redis_response) = messages_redis_response_iter.next() { + match redis_response.frame() { + Some(Frame::Redis(RedisFrame::BulkString(redis_bytes))) => { + // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) + let x = redis_bytes.iter().map(|y| *y).collect_vec(); + let mut cursor = Cursor::new(x.as_slice()); + let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); + if let Ok(result) = answer { + Ok(result) + } else { + Err(CacheableState::Err(answer.err().unwrap().to_string())) + } + } + _ => Err(CacheableState::Err("No Redis frame in Redis response".into())) + } + } else { + Err(CacheableState::Err("Redis response was None".into())) + } + + } + } else { + Ok(CassandraResult::Void) + }; + if let Err(state) = cassandra_result { + return Err( state ); + } + + *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { + version: Version::V4, + operation: CassandraOperation::Result(cassandra_result.ok().unwrap()), + stream_id: cass_request.stream_id().unwrap(), + tracing_id: None, + warnings: vec![], + })); + } + Ok(cassandra_messages) + } + + /// Reads the data from the cache for all the messages or None. + /// on success the values in the messages_cass_request will be modified to be Cassandra response messages. + /// return is the Cassandra response messages or an error containing CacheableState::Skip or CacheableState::Err. + async fn read_from_cache( &mut self, - mut messages_cass_request: Messages, + mut cassandra_messages: Messages, ) -> Result { // This function is a little hard to follow, so here's an overview. // We have 4 vecs of messages, each vec can be considered its own stage of processing. @@ -83,69 +219,13 @@ impl SimpleRedisCache { // + if the request is a CassandraOperation::Query then we consume a single message from messages_redis_response converting it to a cassandra response // * These are the cassandra responses that we return from the function. - info!("get_or_update_from_cache called"); - let mut messages_redis_request = Vec::with_capacity(messages_cass_request.len()); - for cass_request in &mut messages_cass_request { - match cass_request.frame() { - Some(Frame::Cassandra(frame)) => { - for statement in frame.operation.queries() { - let mut state = is_cacheable(statement); + info!("read_from_cache called"); - match state { - // TODO implement proper handling of state - // currently if state is not Skip or Error it just processes - CacheableState::Read - | CacheableState::Update - | CacheableState::Delete => { - info!("get_or_update_from_cache processing cacheable state"); - if let Some(table_name) = CQLStatement::get_table_name(statement) { - if let Some(table_cache_schema) = - self.caching_schema.get(table_name) - { - let redis_state = build_redis_ast_from_cql3( - statement, - table_cache_schema, - ); - if redis_state.is_ok() { - messages_redis_request.push(Message::from_frame( - Frame::Redis(redis_state.ok().unwrap()), - )); - } else { - state = redis_state.err().unwrap(); - } - } else { - state = CacheableState::Skip( - "Table not in caching list".into(), - ); - } - } else { - state = CacheableState::Skip("Table name not in query".into()); - } - } - _ => { - info!("get_or_update_from_cache not processing cacheable state"); - // do nothing here but check again again outside of match as state may have changed - } - } + // build the cache query + let messages_redis_request = self.build_cache_query(&mut cassandra_messages).ok().unwrap(); - if let CacheableState::Err(_) = state { - return Err(state); - } - - if let CacheableState::Skip(_) = state { - return Err(state); - } - } - } - _ => { - return Err(CacheableState::Err(format!( - "cannot fetch {cass_request:?} from cache" - ))); - } - } - } - - info!("get_or_update_from_cache calling cache_chain.process_request"); + // execute the cache query + info!("read_from_cache calling cache_chain.process_request"); match self .cache_chain @@ -156,117 +236,133 @@ impl SimpleRedisCache { .await { Ok(messages_redis_response) => { - info!("get_or_update_from_cache received OK from cache_chain.process_request"); - // Replace cass_request messages with cassandra responses in place. - // We reuse the vec like this to save allocations. - let mut messages_redis_response_iter = messages_redis_response.into_iter(); - for cass_request in &mut messages_cass_request { - let mut redis_responses = vec![]; - if let Some(Frame::Cassandra(frame)) = cass_request.frame() { - for _query in frame.operation.queries() { - redis_responses.push(messages_redis_response_iter.next()); - } - } - - // TODO: Translate the redis_responses into a cassandra result - *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { - version: Version::V4, - operation: CassandraOperation::Result(CassandraResult::Void), - stream_id: cass_request.stream_id().unwrap(), - tracing_id: None, - warnings: vec![], - })); - } - Ok(messages_cass_request) + info!("read_from_cache received OK from cache_chain.process_request"); + self.unwrap_cache_response(messages_redis_response, cassandra_messages) } Err(e) => Err(CacheableState::Err(format!("Redis error: {}", e))), } } -} -fn append_prefix_min(min: &mut Vec) { - if min.is_empty() { - min.push(b'['); - } else { - min.push(b':'); + /// clear the cache for the entry. + fn clear_table_cache(&mut self, cql_statement: &CQLStatement, table_cache_schema: &TableCacheSchema) -> Option { + // TODO is it possible to return the future and process in parallel? + let statement = &cql_statement.statement; + if let Ok(redis_key) = build_redis_key_from_cql3(statement, table_cache_schema) { + let commands_buffer: Vec = vec![ + RedisFrame::BulkString("DEL".into()), + RedisFrame::BulkString(redis_key.into()), + ]; + Some(Message::from_frame(Frame::Redis(RedisFrame::Array(commands_buffer)))) + } else { + None + } } -} -fn append_prefix_max(max: &mut Vec) { - if max.is_empty() { - max.push(b']'); - } else { - max.push(b':'); - } -} -fn build_zrangebylex_min_max_from_sql( - operator: &RelationOperator, - operand: &Operand, - min: &mut Vec, - max: &mut Vec, -) -> Result<(), CacheableState> { - let mut bytes = BytesMut::from(operand.to_string().as_bytes()); - match operator { - RelationOperator::LessThan => { - let last_byte = bytes.last_mut().unwrap(); - *last_byte -= 1; - - append_prefix_max(max); - max.extend(bytes.iter()); - Ok(()) - } - RelationOperator::LessThanOrEqual => { - append_prefix_max(max); - max.extend(bytes.iter()); - Ok(()) - } + /// calls the next transform and process the result for caching. + async fn execute_upstream_and_process_result<'a>(&mut self, message_wrapper: Wrapper<'a> + ) -> ChainResponse { + let mut orig_messages = message_wrapper.messages.clone(); + let orig_cql : Option<&mut CQL> = orig_messages.iter_mut() + .filter_map( |message| { + if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Query{query,..}, .. })) = message.frame() + { + Some(query) + } else { + None + } - RelationOperator::Equal => { - append_prefix_min(min); - append_prefix_max(max); - min.extend(bytes.iter()); - max.extend(bytes.iter()); - Ok(()) - } - RelationOperator::GreaterThanOrEqual => { - append_prefix_min(min); - min.extend(bytes.iter()); - Ok(()) - } - RelationOperator::GreaterThan => { - let last_byte = bytes.last_mut().unwrap(); - *last_byte += 1; - append_prefix_min(min); - min.extend(bytes.iter()); - Ok(()) + } ).next(); + let result_messages = &mut message_wrapper.call_next_transform().await?; + if orig_cql.is_some() { + let mut cache_messages: Vec = vec!(); + for (response, cql_statement) in result_messages.iter_mut().zip(orig_cql.unwrap().statements.iter()) { + if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Result(result), .. })) = response.frame() { + match is_cacheable(cql_statement) { + CacheableState::Update(table_name) | + CacheableState::Delete(table_name) => { + if let Some(table_cache_schema) = self.caching_schema.get(&table_name ) + { + let table_schema = table_cache_schema.clone(); + if let Some(fut_message) = self.clear_table_cache(cql_statement, &table_schema) { + cache_messages.push(fut_message); + } + } else { + info!( "table {} is not being cached", table_name ); + } + } + CacheableState::Read(table_name) => { + let statement = &cql_statement.statement; + if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) + { + if let Ok(redis_key) = build_redis_key_from_cql3( + statement, + table_cache_schema, + ) { + let mut encoded: Vec = Vec::new(); + let mut cursor = Cursor::new(&mut encoded); + result.serialize(&mut cursor); + + let commands_buffer: Vec = vec![ + RedisFrame::BulkString("SET".into()), + RedisFrame::BulkString(redis_key.into()), + RedisFrame::BulkString(encoded.into()), + ]; + + cache_messages.push( Message::from_frame(Frame::Redis(RedisFrame::Array(commands_buffer)))); + + } + } + } + CacheableState::Skip(_reason) | + CacheableState::Err(_reason) => { + // do nothing + } + } + } + } + if ! cache_messages.is_empty() { + let result = self + .cache_chain + .process_request( + Wrapper::new_with_chain_name(cache_messages, self.cache_chain.name.clone()), + "clientdetailstodo".to_string(), + ).await; + if result.is_err() { + warn!( "Cache error: {}", result.err().unwrap()); + } + } } - // should "IN"" be converted to an "or" "eq" combination - RelationOperator::NotEqual - | RelationOperator::In - | RelationOperator::Contains - | RelationOperator::ContainsKey => Err(CacheableState::Skip(format!( - "{} comparisons are not supported", - operator - ))), - RelationOperator::IsNot => Err(CacheableState::Skip( - "IS NOT NULL comparisons are not supported".into(), - )), + Ok(result_messages.to_vec()) } } -fn is_cacheable(statement: &CassandraStatement) -> CacheableState { - let has_params = CQLStatement::has_params(statement); - match statement { - CassandraStatement::Select(select) => { - if has_params { - CacheableState::Delete - } else if select.filtering { - CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) - } else if select.where_clause.is_empty() { - CacheableState::Skip("Can not cache if where clause is empty".into()) - /* } else if !select.columns.is_empty() { +/// Determines if a statement is cacheable. Cacheable statements have several common +/// properties as well as operation specific properties. +/// Common properties include +/// * must specify table name +/// * must not contain a parsing error +/// * +fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { + + // check issues common to all cql_statements + if cql_statement.has_error { + return CacheableState::Skip( "CQL statement has error".into()); + } + if let Some(table_name) = CQLStatement::get_table_name(&cql_statement.statement) + { + let has_params = CQLStatement::has_params(&cql_statement.statement); + + match &cql_statement.statement { + CassandraStatement::Select(select) => { + if has_params { + CacheableState::Delete(table_name.into()) + } else if select.filtering { + CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) + } else if select.where_clause.is_empty() { + CacheableState::Skip("Can not cache if where clause is empty".into()) + /* } else if !select.columns.is_empty() { if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { CacheableState::Read } else { @@ -276,156 +372,140 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { } */ - } else { - CacheableState::Read + } else { + CacheableState::Read(table_name.into()) + } } - } - CassandraStatement::Insert(insert) => { - if has_params || insert.if_not_exists { - CacheableState::Delete - } else { - CacheableState::Update + CassandraStatement::Insert(insert) => { + if has_params || insert.if_not_exists { + CacheableState::Delete(table_name.into()) + } else { + CacheableState::Update(table_name.into()) + } } - } - CassandraStatement::Update(update) => { - if has_params || update.if_exists { - CacheableState::Delete - } else { - for assignment_element in &update.assignments { - if assignment_element.operator.is_some() { - info!( + CassandraStatement::Update(update) => { + if has_params || update.if_exists { + CacheableState::Delete(table_name.into()) + } else { + for assignment_element in &update.assignments { + if assignment_element.operator.is_some() { + info!( "Clearing {} cache: {} has calculations in values", update.table_name, assignment_element.name ); - return CacheableState::Delete; - } - if assignment_element.name.idx.is_some() { - info!( + return CacheableState::Delete(table_name.into()); + } + if assignment_element.name.idx.is_some() { + info!( "Clearing {} cache: {} is an indexed columns", update.table_name, assignment_element.name ); - return CacheableState::Delete; + return CacheableState::Delete(table_name.into()); + } } + CacheableState::Update(table_name.into()) } - CacheableState::Update } - } - _ => CacheableState::Skip("Statement is not a cacheable type".into()), + _ => CacheableState::Skip("Statement is not a cacheable type".into()), + } + } else { + CacheableState::Skip("No table name specified".into()) } } -fn build_redis_ast_from_cql3( - statement: &CassandraStatement, +/// build the redis key for the query. +/// key is cassandra partition key (must be completely specified) prepended to +/// the cassandra range key (may be partially specified) +fn build_query_redis_key_from_value_map( table_cache_schema: &TableCacheSchema, -) -> Result { - match statement { - CassandraStatement::Select(select) => { - let mut min: Vec = Vec::new(); - let mut max: Vec = Vec::new(); - - // extract the partition and range operands - // fail if any are missing - let mut partition_segments: HashMap<&str, &Operand> = HashMap::new(); - let mut range_segments: HashMap<&str, Vec<&RelationElement>> = HashMap::new(); - - for relation_element in &select.where_clause { - if let Operand::Column(column_name) = &relation_element.obj { - // name has to be in partition or range key. - if table_cache_schema.partition_key.contains(column_name) { - partition_segments.insert(column_name, &relation_element.value); - } else if table_cache_schema.range_key.contains(column_name) { - let value = range_segments.get_mut(column_name.as_str()); - if let Some(vec) = value { - vec.push(relation_element) - } else { - range_segments.insert(column_name, vec![relation_element]); - }; - } else { - return Err(CacheableState::Skip(format!( - "Couldn't build query - column {} is not in the key", - column_name - ))); + query_values: &BTreeMap>, +) -> Result { + + let mut redis_key = BytesMut::new(); + + for c_name in &table_cache_schema.partition_key { + let column_name = c_name.to_lowercase(); + match query_values.get( column_name.as_str() ) { + None => { return Err( CacheableState::Skip( format!("Partition key not complete. missing segment {}", column_name )));}, + Some( relation_elements ) => { + if relation_elements.len() > 1 { + return Err( CacheableState::Skip( format!("partition key segment {} has more than one relationship", column_name))); } + redis_key.extend( relation_elements[0].value.to_string().as_bytes() ); } } - let mut skipping = false; - for column_name in &table_cache_schema.range_key { - if let Some(relation_elements) = range_segments.get(column_name.as_str()) { + } + let mut skipping = false; + + for c_name in &table_cache_schema.range_key { + let column_name = c_name.to_lowercase(); + match query_values.get( column_name.as_str() ) { + None => { skipping = true; }, + Some( relation_elements ) => { if skipping { // we skipped an earlier column so this is an error. return Err(CacheableState::Err( "Columns in the middle of the range key were skipped".into(), )); } - for range_element in relation_elements { - if let Err(e) = build_zrangebylex_min_max_from_sql( - &range_element.oper, - &range_element.value, - &mut min, - &mut max, - ) { - return Err(e); + let mut my_elements = relation_elements.clone(); + + my_elements.sort(); + for relation_element in my_elements { + redis_key.extend( relation_element.to_string().as_bytes() ); } - } - } else { - // once we skip a range key column we have to skip all the rest so set a flag. - skipping = true; - } - } - let min = if min.is_empty() { - Bytes::from_static(b"-") - } else { - Bytes::from(min) - }; - let max = if max.is_empty() { - Bytes::from_static(b"+") - } else { - Bytes::from(max) - }; - let mut partition_key = BytesMut::new(); - for column_name in &table_cache_schema.partition_key { - if let Some(operand) = partition_segments.get(column_name.as_str()) { - partition_key.extend(operand.to_string().as_bytes()); - } else { - return Err(CacheableState::Err(format!( - "partition column {} missing", - column_name - ))); } } + } + Ok(redis_key) +} - let commands_buffer = vec![ - RedisFrame::BulkString("ZRANGEBYLEX".into()), - RedisFrame::BulkString(partition_key.freeze()), - RedisFrame::BulkString(min), - RedisFrame::BulkString(max), - ]; - Ok(RedisFrame::Array(commands_buffer)) + +fn populate_value_map_from_where_clause(value_map : &mut BTreeMap>, where_clause : &Vec) { + for relation_element in where_clause { + let column_name = relation_element.obj.to_string().to_lowercase(); + let value = value_map.get_mut(column_name.as_str()); + if let Some(vec) = value { + vec.push(relation_element.clone()) + } else { + value_map.insert(column_name, vec![relation_element.clone()]); + }; + } +} + +fn build_redis_key_from_cql3( + statement: &CassandraStatement, + table_cache_schema: &TableCacheSchema, +) -> Result { + let mut value_map : BTreeMap> = BTreeMap::new(); + match statement { + CassandraStatement::Select(select) => { + populate_value_map_from_where_clause( &mut value_map, &select.where_clause ); + build_query_redis_key_from_value_map(table_cache_schema, &value_map) } + CassandraStatement::Insert(insert) => { - let query_values = insert.get_value_map(); - add_query_values(table_cache_schema, query_values) + for (c_name, operand) in insert.get_value_map().into_iter() { + let column_name = c_name.to_lowercase(); + let relation_element = RelationElement{ + obj: Operand::Column( column_name.clone()), + oper: RelationOperator::Equal, + value: operand.clone(), + }; + let value = value_map.get_mut(column_name.as_str()); + if let Some(vec) = value { + vec.push( relation_element) + } else { + value_map.insert(column_name, vec![relation_element]); + }; + } + build_query_redis_key_from_value_map(table_cache_schema, &value_map) } CassandraStatement::Update(update) => { - let mut query_values: BTreeMap = BTreeMap::new(); - - update.assignments.iter().for_each(|assignment| { - query_values.insert(assignment.name.to_string(), &assignment.value); - }); - for relation_element in &update.where_clause { - if relation_element.oper == RelationOperator::Equal { - if let Operand::Column(name) = &relation_element.obj { - if table_cache_schema.partition_key.contains(name) - || table_cache_schema.range_key.contains(name) - { - query_values.insert(name.clone(), &relation_element.value); - } - } - } - } - add_query_values(table_cache_schema, query_values) + populate_value_map_from_where_clause( &mut value_map, &update.where_clause); + build_query_redis_key_from_value_map(table_cache_schema, &value_map) } _ => unreachable!( "{} should not be passed to build_redis_ast_from_cql3", @@ -434,113 +514,27 @@ fn build_redis_ast_from_cql3( } } -fn add_query_values( - table_cache_schema: &TableCacheSchema, - query_values: BTreeMap, -) -> Result { - let mut partition_key = BytesMut::new(); - for column_name in &table_cache_schema.partition_key { - if let Some(operand) = query_values.get(column_name) { - partition_key.extend(operand.to_string().as_bytes()); - } else { - return Err(CacheableState::Err(format!( - "partition column {} missing", - column_name - ))); - } - } - - let mut clustering = BytesMut::new(); - for column_name in &table_cache_schema.range_key { - if let Some(operand) = query_values.get(column_name.as_str()) { - clustering.extend(operand.to_string().as_bytes()); - } else { - return Err(CacheableState::Err(format!( - "range column {} missing", - column_name - ))); - } - } - - let mut commands_buffer: Vec = vec![ - RedisFrame::BulkString("ZADD".into()), - RedisFrame::BulkString(partition_key.freeze()), - ]; - - // get values not in partition or cluster key - let values = query_values - .iter() - .filter_map(|(column_name, value)| { - if !table_cache_schema - .partition_key - .contains(&column_name.to_string()) - && !table_cache_schema - .range_key - .contains(&column_name.to_string()) - { - Some(value) - } else { - None - } - }) - .collect_vec(); - - for operand in values { - commands_buffer.push(RedisFrame::BulkString(Bytes::from_static(b"0"))); - let mut value = clustering.clone(); - if !value.is_empty() { - value.put_u8(b':'); - } - value.extend(operand.to_string().as_bytes()); - commands_buffer.push(RedisFrame::BulkString(value.freeze())); - } - - Ok(RedisFrame::Array(commands_buffer)) -} #[async_trait] impl Transform for SimpleRedisCache { - async fn transform<'a>(&'a mut self, message_wrapper: Wrapper<'a>) -> ChainResponse { + async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { let mut read_cache = true; - for m in &message_wrapper.messages { - if let Some(Frame::Cassandra(CassandraFrame{ operation : CassandraOperation::Query{ query , ..},..})) = m.frame() { - - - - - //if let Some(&mut Frame::Cassandra(CassandraFrame { - // operation: &CassandraOperation::Query { query, .. }, - // .. - // })) = &mut m.frame() - // { - /* let statement = query.get_statement(); - match is_cacheable(statement ) { - CacheableState::Read | - CacheableState::Update | - CacheableState::Delete => {} - CacheableState::Skip(reason) => { - tracing::info!( "Cache skipped for {} due to {}", statement, reason ); - use_cache = false; - } - CacheableState::Err(reason) => { - tracing::error!("Cache failed for {} due to {}", statement, reason); - use_cache = false; - } - } - - */ - for cql_statement in &query.statements { - info!("cache transform processing {}", cql_statement); - match cql_statement.get_query_type() { - QueryType::Read => {} - QueryType::Write => read_cache = false, - QueryType::ReadWrite => read_cache = false, - QueryType::SchemaChange => read_cache = false, - QueryType::PubSubMessage => {} - } + for m in &mut message_wrapper.messages { + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) = m.frame() + { + for cql_statement in &query.statements { + info!("cache transform processing {}", cql_statement); + match cql_statement.get_query_type() { + QueryType::Read => {} + QueryType::Write => read_cache = false, + QueryType::ReadWrite => read_cache = false, + QueryType::SchemaChange => read_cache = false, + QueryType::PubSubMessage => {} } - - + } } else { read_cache = false; } @@ -550,33 +544,30 @@ impl Transform for SimpleRedisCache { // If there are no write queries (all queries are reads) we can read the cache if read_cache { match self - .get_or_update_from_cache(message_wrapper.messages.clone()) + .read_from_cache(message_wrapper.messages.clone()) .await { Ok(cr) => return Ok(cr), Err(inner_state) => match &inner_state { - CacheableState::Read | CacheableState::Update | CacheableState::Delete => { - unreachable!("should not find read, update or delete as an error"); - } CacheableState::Skip(reason) => { tracing::info!("Cache skipped: {} ", reason); - message_wrapper.call_next_transform().await + self.execute_upstream_and_process_result( message_wrapper ).await } CacheableState::Err(reason) => { tracing::error!("Cache failed: {} ", reason); message_wrapper.call_next_transform().await } + _ => { + unreachable!("should not find read, update or delete as an error."); + } }, } } else { - let (_cache_res, upstream) = tokio::join!( - self.get_or_update_from_cache(message_wrapper.messages.clone()), - message_wrapper.call_next_transform() - ); - upstream + self.execute_upstream_and_process_result( message_wrapper ).await } } + fn validate(&self) -> Vec { let mut errors = self .cache_chain @@ -610,7 +601,7 @@ mod test { fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); assert!(!cql.has_error); - cql.get_statement().clone() + cql.statements[0].statement.clone() } #[test] diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index df4a0fc91..5767d5df9 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,7 +1,7 @@ use cassandra_cpp::{stmt, Cluster, Error, Session, Value, ValueType}; use ordered_float::OrderedFloat; -use tracing::info; use test_helpers::try_wait_for_socket_to_open; +use tracing::info; pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { for contact_point in contact_points.split(',') { @@ -109,7 +109,7 @@ impl ResultValue { #[allow(unused)] pub fn execute_query(session: &Session, query: &str) -> Vec> { let statement = stmt!(query); - info!( "executing query: {}", query); + info!("executing query: {}", query); match session.execute(&statement).wait() { Ok(result) => result .into_iter() From 5d79b450314d6c876457f88c9d0cb76032d3b9ec Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 14 Apr 2022 15:31:45 +0100 Subject: [PATCH 21/60] updated caching tests --- shotover-proxy/src/frame/cassandra.rs | 13 +- .../src/transforms/cassandra/peers_rewrite.rs | 2 +- shotover-proxy/src/transforms/redis/cache.rs | 272 +++++++++--------- 3 files changed, 142 insertions(+), 145 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 3dcd4c4e1..5c8587bee 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -527,10 +527,7 @@ impl CQLStatement { } pub fn is_apply_batch(&self) -> bool { - match &self.statement { - CassandraStatement::ApplyBatch => true, - _ => false, - } + matches!(&self.statement, CassandraStatement::ApplyBatch) } /// returns the query type for the current statement. @@ -1093,17 +1090,17 @@ impl CassandraResult { } CassandraResult::Rows { value : MessageValue::Rows(value), - metadata: body_res_result_rows.metadata.clone(), + metadata: body_res_result_rows.metadata, } }, ResResultBody::SetKeyspace(keyspace) => { - CassandraResult::SetKeyspace( Box::new( keyspace.clone() )) + CassandraResult::SetKeyspace( Box::new( keyspace )) } ResResultBody::Prepared(prepared) => { - CassandraResult::Prepared(Box::new( prepared.clone())) + CassandraResult::Prepared(Box::new( prepared)) } ResResultBody::SchemaChange(schema_change) => { - CassandraResult::SchemaChange( schema_change.clone()) + CassandraResult::SchemaChange( schema_change ) } }) } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 7993f7142..8fb9e8225 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -76,7 +76,7 @@ fn extract_native_port_column(message: &mut Message) -> Vec { for cql_statement in &query.statements { let statement = &cql_statement.statement; if let CassandraStatement::Select(select) = &statement { - if let Some(table_name) = CQLStatement::get_table_name(&statement) { + if let Some(table_name) = CQLStatement::get_table_name(statement) { if table_name.eq("system.peers_v2") { select .columns diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 0fcf2e1bf..e61e0dc43 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -19,6 +19,7 @@ use std::collections::{BTreeMap, HashMap}; use std::io::Cursor; use tracing_log::log::{info, warn}; use cassandra_protocol::frame::Serialize; +use cql3_parser::select::{Named, Select, SelectElement}; use itertools::Itertools; /* @@ -101,10 +102,11 @@ impl SimpleRedisCache { statement, table_cache_schema, ) { - Ok(redis_key) => { + Ok((redis_key,hash_key)) => { let commands_buffer = vec![ - RedisFrame::BulkString("GET".into()), - RedisFrame::BulkString(redis_key.into()), ]; + RedisFrame::BulkString("HGET".into()), + RedisFrame::BulkString(redis_key.into()), + RedisFrame::BulkString(hash_key.into()), ]; messages_redis_request.push(Message::from_frame( Frame::Redis(RedisFrame::Array(commands_buffer)), @@ -155,26 +157,23 @@ impl SimpleRedisCache { let queries = frame.operation.queries(); if queries.len() != 1 { Err(CacheableState::Err("Cacheable Cassandra query must be only one statement".into())) - } else { - if let Some(mut redis_response) = messages_redis_response_iter.next() { - match redis_response.frame() { - Some(Frame::Redis(RedisFrame::BulkString(redis_bytes))) => { - // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) - let x = redis_bytes.iter().map(|y| *y).collect_vec(); - let mut cursor = Cursor::new(x.as_slice()); - let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); - if let Ok(result) = answer { - Ok(result) - } else { - Err(CacheableState::Err(answer.err().unwrap().to_string())) - } + } else if let Some(mut redis_response) = messages_redis_response_iter.next() { + match redis_response.frame() { + Some(Frame::Redis(RedisFrame::BulkString(redis_bytes))) => { + // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) + let x = redis_bytes.iter().copied().collect_vec(); + let mut cursor = Cursor::new(x.as_slice()); + let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); + if let Ok(result) = answer { + Ok(result) + } else { + Err(CacheableState::Err(answer.err().unwrap().to_string())) } - _ => Err(CacheableState::Err("No Redis frame in Redis response".into())) } - } else { - Err(CacheableState::Err("Redis response was None".into())) + _ => Err(CacheableState::Err("No Redis frame in Redis response".into())) } - + } else { + Err(CacheableState::Err("Redis response was None".into())) } } else { Ok(CassandraResult::Void) @@ -247,9 +246,9 @@ impl SimpleRedisCache { fn clear_table_cache(&mut self, cql_statement: &CQLStatement, table_cache_schema: &TableCacheSchema) -> Option { // TODO is it possible to return the future and process in parallel? let statement = &cql_statement.statement; - if let Ok(redis_key) = build_redis_key_from_cql3(statement, table_cache_schema) { + if let Ok((redis_key,_hash_key)) = build_redis_key_from_cql3(statement, table_cache_schema) { let commands_buffer: Vec = vec![ - RedisFrame::BulkString("DEL".into()), + RedisFrame::BulkString("GETDEL".into()), RedisFrame::BulkString(redis_key.into()), ]; Some(Message::from_frame(Frame::Redis(RedisFrame::Array(commands_buffer)))) @@ -295,7 +294,7 @@ impl SimpleRedisCache { let statement = &cql_statement.statement; if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) { - if let Ok(redis_key) = build_redis_key_from_cql3( + if let Ok((redis_key,hash_key)) = build_redis_key_from_cql3( statement, table_cache_schema, ) { @@ -304,8 +303,9 @@ impl SimpleRedisCache { result.serialize(&mut cursor); let commands_buffer: Vec = vec![ - RedisFrame::BulkString("SET".into()), + RedisFrame::BulkString("HSET".into()), RedisFrame::BulkString(redis_key.into()), + RedisFrame::BulkString(hash_key.into()), RedisFrame::BulkString(encoded.into()), ]; @@ -421,9 +421,7 @@ fn build_query_redis_key_from_value_map( table_cache_schema: &TableCacheSchema, query_values: &BTreeMap>, ) -> Result { - - let mut redis_key = BytesMut::new(); - + let mut key : Vec = vec!(); for c_name in &table_cache_schema.partition_key { let column_name = c_name.to_lowercase(); match query_values.get( column_name.as_str() ) { @@ -432,7 +430,10 @@ fn build_query_redis_key_from_value_map( if relation_elements.len() > 1 { return Err( CacheableState::Skip( format!("partition key segment {} has more than one relationship", column_name))); } - redis_key.extend( relation_elements[0].value.to_string().as_bytes() ); + if !key.is_empty() { + key.push( b':' ) + } + key.extend(relation_elements[0].value.to_string().as_bytes() ); } } } @@ -449,21 +450,76 @@ fn build_query_redis_key_from_value_map( "Columns in the middle of the range key were skipped".into(), )); } + if relation_elements.len() > 1 { + return Err( CacheableState::Skip( format!("partition key segment {} has more than one relationship", column_name))); + } + if !key.is_empty() { + key.push( b':' ) + } + key.extend(relation_elements[0].value.to_string().as_bytes() ); + /* let mut my_elements = relation_elements.clone(); my_elements.sort(); for relation_element in my_elements { - redis_key.extend( relation_element.to_string().as_bytes() ); + if !redis_key.is_empty() { + redis_key.extend(':'); + } + redis_key.extend( relation_element.value.to_string().as_bytes() ); } - +*/ } } } - Ok(redis_key) + Ok(BytesMut::from( key.as_slice() )) } +/// build the redis key for the query. +/// key is cassandra partition key (must be completely specified) prepended to +/// the cassandra range key (may be partially specified) +fn build_query_redis_hash_from_value_map( + table_cache_schema: &TableCacheSchema, + query_values: &BTreeMap>, + select : &Select +) -> Result { -fn populate_value_map_from_where_clause(value_map : &mut BTreeMap>, where_clause : &Vec) { + let mut my_values = query_values.clone(); + for c_name in &table_cache_schema.partition_key { + let column_name = c_name.to_lowercase(); + my_values.remove(&column_name); + } + for c_name in &table_cache_schema.range_key { + let column_name = c_name.to_lowercase(); + my_values.remove(&column_name); + } + + let mut str = if select.columns.is_empty() { + String::from( "WHERE ") + } else { + let mut tmp = select.columns.iter().map(|select_element| { + match select_element { + SelectElement::Star => { SelectElement::Star } + SelectElement::Column(named) => { SelectElement::Column(Named { name: named.name.to_lowercase(), alias: match &named.alias{ + None => None, + Some(name) => Some(name.to_lowercase()), + } }) }, + SelectElement::Function(named) => {SelectElement::Function(Named { name: named.name.to_lowercase(), alias: match &named.alias{ + None => None, + Some(name) => Some(name.to_lowercase()), + } })} + } + }).join( ", "); + tmp.push_str( " WHERE "); + tmp + }; + str.push_str( my_values.iter_mut().sorted() + .flat_map( |(_k,v)| v.iter() ).join(" AND ").as_str()); + + Ok(BytesMut::from( str.as_str() )) +} + + +fn populate_value_map_from_where_clause(value_map : &mut BTreeMap>, where_clause : &[RelationElement]) { for relation_element in where_clause { let column_name = relation_element.obj.to_string().to_lowercase(); let value = value_map.get_mut(column_name.as_str()); @@ -478,12 +534,13 @@ fn populate_value_map_from_where_clause(value_map : &mut BTreeMap Result { +) -> Result<(BytesMut,BytesMut), CacheableState> { let mut value_map : BTreeMap> = BTreeMap::new(); match statement { CassandraStatement::Select(select) => { populate_value_map_from_where_clause( &mut value_map, &select.where_clause ); - build_query_redis_key_from_value_map(table_cache_schema, &value_map) + Ok((build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + build_query_redis_hash_from_value_map(table_cache_schema, &value_map, &select)?)) } CassandraStatement::Insert(insert) => { @@ -501,14 +558,16 @@ fn build_redis_key_from_cql3( value_map.insert(column_name, vec![relation_element]); }; } - build_query_redis_key_from_value_map(table_cache_schema, &value_map) + Ok((build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + BytesMut::new() )) } CassandraStatement::Update(update) => { populate_value_map_from_where_clause( &mut value_map, &update.where_clause); - build_query_redis_key_from_value_map(table_cache_schema, &value_map) + Ok((build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + BytesMut::new() )) } _ => unreachable!( - "{} should not be passed to build_redis_ast_from_cql3", + "{} should not be passed to build_redis_key_from_cql3", statement ), } @@ -590,11 +649,9 @@ mod test { use crate::transforms::chain::TransformChain; use crate::transforms::debug::printer::DebugPrinter; use crate::transforms::null::Null; - use crate::transforms::redis::cache::{ - build_redis_ast_from_cql3, SimpleRedisCache, TableCacheSchema, - }; + use crate::transforms::redis::cache::{build_query_redis_hash_from_value_map, build_redis_key_from_cql3, SimpleRedisCache, TableCacheSchema}; use crate::transforms::{Transform, Transforms}; - use bytes::Bytes; + use bytes::{Bytes, BytesMut}; use cql3_parser::cassandra_statement::CassandraStatement; use std::collections::HashMap; @@ -613,18 +670,12 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"[123:965")), - RedisFrame::BulkString(Bytes::from_static(b"]123:965")), - ]); - - assert_eq!(expected, query); + assert_eq!( BytesMut::from( "1:123:965"), redis_key); + assert_eq!( BytesMut::from( "* WHERE "), hash_key ); } #[test] @@ -636,18 +687,12 @@ mod test { let ast = build_query("INSERT INTO foo (z, v) VALUES (1, 123)"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZADD")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"123")), - ]); - - assert_eq!(expected, query); + assert_eq!( BytesMut::from( "1"), redis_key); + assert!( hash_key.is_empty()); } #[test] @@ -658,40 +703,28 @@ mod test { }; let ast = build_query("INSERT INTO foo (z, c, v) VALUES (1, 'yo' , 123)"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZADD")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), - ]); - - assert_eq!(expected, query); + assert_eq!(BytesMut::from("1:'yo'"), redis_key); + assert!( hash_key.is_empty() ); } #[test] fn update_simple_clustering_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec!["c".to_string()], + range_key: vec![], }; let ast = build_query("UPDATE foo SET c = 'yo', v = 123 WHERE z = 1"); - let result = build_redis_ast_from_cql3(&ast, &table_cache_schema); - let query = result.ok().unwrap(); - - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZADD")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"0")), - RedisFrame::BulkString(Bytes::from_static(b"'yo':123")), - ]); + let result = build_redis_key_from_cql3(&ast, &table_cache_schema); + let (redis_key, hash_key) = result.ok().unwrap(); - assert_eq!(expected, query); + assert_eq!(BytesMut::from("1"), redis_key); + assert!( hash_key.is_empty()); } #[test] @@ -703,13 +736,13 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); - let query_one = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let query_one = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); let ast = build_query("SELECT * FROM foo WHERE y = 965 AND z = 1 AND x = 123"); - let query_two = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let query_two = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); @@ -722,46 +755,36 @@ mod test { fn range_exclusive_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec!["x".to_string()], + range_key: vec![], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x > 123 AND x < 999"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"[124")), - RedisFrame::BulkString(Bytes::from_static(b"]998")), - ]); + assert_eq!(BytesMut::from( "1" ), redis_key); + assert_eq!(BytesMut::from( "* WHERE x > 123 AND x < 999"), hash_key); - assert_eq!(expected, query); } #[test] fn range_inclusive_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec!["x".to_string()], + range_key: vec![], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123 AND x <= 999"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"[123")), - RedisFrame::BulkString(Bytes::from_static(b"]999")), - ]); + assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from( "* WHERE x >= 123 AND x <= 999"), hash_key); - assert_eq!(expected, query); } #[test] @@ -773,18 +796,12 @@ mod test { let ast = build_query("SELECT * FROM foo WHERE z = 1"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"-")), - RedisFrame::BulkString(Bytes::from_static(b"+")), - ]); - - assert_eq!(expected, query); + assert_eq!(BytesMut::from( "1" ), redis_key); +assert_eq!( BytesMut::from( "* WHERE "), hash_key); } #[test] @@ -794,58 +811,41 @@ mod test { range_key: vec![], }; - let ast = build_query("SELECT * FROM foo WHERE z = 1 AND y = 2"); + let ast = build_query("SELECT thing FROM foo WHERE z = 1 AND y = 2"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"12")), - RedisFrame::BulkString(Bytes::from_static(b"-")), - RedisFrame::BulkString(Bytes::from_static(b"+")), - ]); - - assert_eq!(expected, query); + assert_eq!(BytesMut::from("1:2"), key); + assert_eq!(BytesMut::from( "thing WHERE "), hash_key); } #[test] fn open_range_test() { let table_cache_schema = TableCacheSchema { partition_key: vec!["z".to_string()], - range_key: vec!["x".to_string()], + range_key: vec![], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x >= 123"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"[123")), - RedisFrame::BulkString(Bytes::from_static(b"+")), - ]); - - assert_eq!(expected, query); + assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("* WHERE x >= 123"), hash_key ); let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x <= 123"); - let query = build_redis_ast_from_cql3(&ast, &table_cache_schema) + let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - let expected = RedisFrame::Array(vec![ - RedisFrame::BulkString(Bytes::from_static(b"ZRANGEBYLEX")), - RedisFrame::BulkString(Bytes::from_static(b"1")), - RedisFrame::BulkString(Bytes::from_static(b"-")), - RedisFrame::BulkString(Bytes::from_static(b"]123")), - ]); + assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("* WHERE x <= 123"), hash_key ); - assert_eq!(expected, query); } #[tokio::test] From 007c9c19a54e149b3206b6bb8fb43036afc49294 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 14 Apr 2022 15:50:15 +0100 Subject: [PATCH 22/60] fixed tests --- shotover-proxy/src/frame/cassandra.rs | 102 ++-- shotover-proxy/src/transforms/redis/cache.rs | 492 +++++++++++-------- shotover-proxy/tests/helpers/cassandra.rs | 1 - 3 files changed, 324 insertions(+), 271 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 5c8587bee..828528ccf 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -8,8 +8,13 @@ use cassandra_protocol::frame::frame_error::ErrorBody; use cassandra_protocol::frame::frame_query::BodyReqQuery; use cassandra_protocol::frame::frame_request::RequestBody; use cassandra_protocol::frame::frame_response::ResponseBody; -use cassandra_protocol::frame::frame_result::{BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ColSpec, ResResultBody, RowsMetadata, RowsMetadataFlags}; -use cassandra_protocol::frame::{Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version}; +use cassandra_protocol::frame::frame_result::{ + BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ColSpec, ResResultBody, + RowsMetadata, RowsMetadataFlags, +}; +use cassandra_protocol::frame::{ + Direction, Flags, Frame as RawCassandraFrame, Opcode, Serialize, StreamId, Version, +}; use cassandra_protocol::query::{QueryParams, QueryValues}; use cassandra_protocol::types::blob::Blob; use cassandra_protocol::types::cassandra_type::CassandraType; @@ -1029,78 +1034,71 @@ pub enum CassandraResult { impl Serialize for CassandraResult { fn serialize(&self, cursor: &mut Cursor<&mut Vec>) { - let res_result_body : ResResultBody = match self { - CassandraResult::Rows { value,metadata } => { - match value { - MessageValue::Rows(rows) => { - let mut rows_content: Vec> = Vec::with_capacity(rows.len()); - for row in rows { - let mut row_data = Vec::with_capacity(row.len()); - for element in row { - let b = cassandra_protocol::types::value::Bytes::from(element.clone()); - row_data.push(CBytes::new(b.into_inner())); - } - rows_content.push(row_data); + let res_result_body: ResResultBody = match self { + CassandraResult::Rows { value, metadata } => match value { + MessageValue::Rows(rows) => { + let mut rows_content: Vec> = Vec::with_capacity(rows.len()); + for row in rows { + let mut row_data = Vec::with_capacity(row.len()); + for element in row { + let b = cassandra_protocol::types::value::Bytes::from(element.clone()); + row_data.push(CBytes::new(b.into_inner())); } - let body_res_result_rows = BodyResResultRows { - metadata: metadata.clone(), - rows_count: rows.len() as CInt, - rows_content - }; - ResResultBody::Rows(body_res_result_rows) + rows_content.push(row_data); } - _ => ResResultBody::Void + let body_res_result_rows = BodyResResultRows { + metadata: metadata.clone(), + rows_count: rows.len() as CInt, + rows_content, + }; + ResResultBody::Rows(body_res_result_rows) } - } - CassandraResult::SetKeyspace( keyspace ) => { - ResResultBody::SetKeyspace(*keyspace.clone()) - } - CassandraResult::Prepared( prepared ) => { - ResResultBody::Prepared( *prepared.clone() ) - } + _ => ResResultBody::Void, + }, + CassandraResult::SetKeyspace(keyspace) => ResResultBody::SetKeyspace(*keyspace.clone()), + CassandraResult::Prepared(prepared) => ResResultBody::Prepared(*prepared.clone()), CassandraResult::SchemaChange(schema_change) => { - ResResultBody::SchemaChange( schema_change.clone() ) - } - CassandraResult::Void => { - ResResultBody::Void + ResResultBody::SchemaChange(schema_change.clone()) } + CassandraResult::Void => ResResultBody::Void, }; res_result_body.serialize(cursor); } } impl CassandraResult { - pub fn from_cursor( - cursor: &mut Cursor<&[u8]>, - version: Version, - ) -> Result { - + pub fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> Result { let res_result_body = ResResultBody::from_cursor(cursor, version)?; Ok(match res_result_body { ResResultBody::Void => CassandraResult::Void, - ResResultBody::Rows( body_res_result_rows) => { - let mut value : Vec> = Vec::with_capacity(body_res_result_rows.rows_content.len()); - for row in &body_res_result_rows.rows_content { - let mut row_values = Vec::with_capacity( body_res_result_rows.metadata.col_specs.len()); - for (cbytes,colspec) in row.iter().zip( body_res_result_rows.metadata.col_specs.iter() ) { - row_values.push( MessageValue::build_value_from_cstar_col_type(colspec, cbytes) ); - } - value.push(row_values); + ResResultBody::Rows(body_res_result_rows) => { + let mut value: Vec> = + Vec::with_capacity(body_res_result_rows.rows_content.len()); + for row in &body_res_result_rows.rows_content { + let mut row_values = + Vec::with_capacity(body_res_result_rows.metadata.col_specs.len()); + for (cbytes, colspec) in row + .iter() + .zip(body_res_result_rows.metadata.col_specs.iter()) + { + row_values.push(MessageValue::build_value_from_cstar_col_type( + colspec, cbytes, + )); } + value.push(row_values); + } CassandraResult::Rows { - value : MessageValue::Rows(value), + value: MessageValue::Rows(value), metadata: body_res_result_rows.metadata, } - }, - ResResultBody::SetKeyspace(keyspace) => { - CassandraResult::SetKeyspace( Box::new( keyspace )) } - ResResultBody::Prepared(prepared) => { - CassandraResult::Prepared(Box::new( prepared)) + ResResultBody::SetKeyspace(keyspace) => { + CassandraResult::SetKeyspace(Box::new(keyspace)) } + ResResultBody::Prepared(prepared) => CassandraResult::Prepared(Box::new(prepared)), ResResultBody::SchemaChange(schema_change) => { - CassandraResult::SchemaChange( schema_change ) + CassandraResult::SchemaChange(schema_change) } }) } diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index e61e0dc43..f97b9993f 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,8 +1,7 @@ - use crate::config::topology::TopicHolder; use crate::error::ChainResponse; use crate::frame::cassandra::CQLStatement; -use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, CQL, Frame, RedisFrame}; +use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame, CQL}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::TransformChain; use crate::transforms::{ @@ -11,16 +10,16 @@ use crate::transforms::{ use anyhow::Result; use async_trait::async_trait; use bytes::BytesMut; +use cassandra_protocol::frame::Serialize; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, RelationElement, RelationOperator}; -use serde::{Deserialize}; +use cql3_parser::select::{Named, Select, SelectElement}; +use itertools::Itertools; +use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; use std::io::Cursor; use tracing_log::log::{info, warn}; -use cassandra_protocol::frame::Serialize; -use cql3_parser::select::{Named, Select, SelectElement}; -use itertools::Itertools; /* Uses redis as a cache. Data is stored in Redis as a Hash. @@ -76,7 +75,6 @@ pub struct SimpleRedisCache { caching_schema: HashMap, } - impl SimpleRedisCache { fn get_name(&self) -> &'static str { "SimpleRedisCache" @@ -84,8 +82,10 @@ impl SimpleRedisCache { /// Build the messages for the cache query from the cassandra request messages. /// returns the Redis Messages or a `CacheableState:Err` or `CacheableState::Skip` as the error - fn build_cache_query(&mut self, - cassandra_messages: &mut Messages, ) -> Result { + fn build_cache_query( + &mut self, + cassandra_messages: &mut Messages, + ) -> Result { let mut messages_redis_request = Vec::with_capacity(cassandra_messages.len()); for cass_request in cassandra_messages { match &mut cass_request.frame() { @@ -96,37 +96,40 @@ impl SimpleRedisCache { let statement = &cql_statement.statement; info!("build_cache_query processing cacheable state"); if let Some(table_cache_schema) = - self.caching_schema.get(table_name.as_str()) + self.caching_schema.get(table_name.as_str()) { - match build_redis_key_from_cql3( - statement, - table_cache_schema, - ) { - Ok((redis_key,hash_key)) => { - let commands_buffer = vec![ - RedisFrame::BulkString("HGET".into()), - RedisFrame::BulkString(redis_key.into()), - RedisFrame::BulkString(hash_key.into()), ]; - - messages_redis_request.push(Message::from_frame( - Frame::Redis(RedisFrame::Array(commands_buffer)), - )); - }, - Err(err_state) => {state = err_state;} + match build_redis_key_from_cql3(statement, table_cache_schema) { + Ok((redis_key, hash_key)) => { + let commands_buffer = vec![ + RedisFrame::BulkString("HGET".into()), + RedisFrame::BulkString(redis_key.into()), + RedisFrame::BulkString(hash_key.into()), + ]; + + messages_redis_request.push(Message::from_frame( + Frame::Redis(RedisFrame::Array(commands_buffer)), + )); + } + Err(err_state) => { + state = err_state; + } } } else { - state = CacheableState::Skip( - format!("Table {} not in caching list", table_name) - ); + state = CacheableState::Skip(format!( + "Table {} not in caching list", + table_name + )); } } else { - state = CacheableState::Skip(format!("{} is not a readable query",cql_statement)); + state = CacheableState::Skip(format!( + "{} is not a readable query", + cql_statement + )); } match state { - CacheableState::Err(_) | - CacheableState::Skip(_) => { return Err(state) } - _ => {}, + CacheableState::Err(_) | CacheableState::Skip(_) => return Err(state), + _ => {} } } } @@ -143,7 +146,11 @@ impl SimpleRedisCache { /// unwraps redis response messages into cassandra messages. It does this by replacing the Cassandra /// request messages with their corresponding Cassandra response messages and returns them. /// Result is either the modified message_cass_request (now response messages) or and CacheableState::Err. - fn unwrap_cache_response(&self, messages_redis_response: Messages, mut cassandra_messages: Messages) -> Result { + fn unwrap_cache_response( + &self, + messages_redis_response: Messages, + mut cassandra_messages: Messages, + ) -> Result { // Replace cass_request messages with cassandra responses in place. // We reuse the vec like this to save allocations. let mut messages_redis_response_iter = messages_redis_response.into_iter(); @@ -153,33 +160,38 @@ impl SimpleRedisCache { for cass_request in cassandra_messages.iter_mut() { // the responses for this request - let cassandra_result : Result = if let Some(Frame::Cassandra(frame)) = &mut cass_request.frame() { - let queries = frame.operation.queries(); - if queries.len() != 1 { - Err(CacheableState::Err("Cacheable Cassandra query must be only one statement".into())) - } else if let Some(mut redis_response) = messages_redis_response_iter.next() { - match redis_response.frame() { - Some(Frame::Redis(RedisFrame::BulkString(redis_bytes))) => { - // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) - let x = redis_bytes.iter().copied().collect_vec(); - let mut cursor = Cursor::new(x.as_slice()); - let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); - if let Ok(result) = answer { - Ok(result) - } else { - Err(CacheableState::Err(answer.err().unwrap().to_string())) + let cassandra_result: Result = + if let Some(Frame::Cassandra(frame)) = &mut cass_request.frame() { + let queries = frame.operation.queries(); + if queries.len() != 1 { + Err(CacheableState::Err( + "Cacheable Cassandra query must be only one statement".into(), + )) + } else if let Some(mut redis_response) = messages_redis_response_iter.next() { + match redis_response.frame() { + Some(Frame::Redis(RedisFrame::BulkString(redis_bytes))) => { + // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) + let x = redis_bytes.iter().copied().collect_vec(); + let mut cursor = Cursor::new(x.as_slice()); + let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); + if let Ok(result) = answer { + Ok(result) + } else { + Err(CacheableState::Err(answer.err().unwrap().to_string())) + } } + _ => Err(CacheableState::Err( + "No Redis frame in Redis response".into(), + )), } - _ => Err(CacheableState::Err("No Redis frame in Redis response".into())) + } else { + Err(CacheableState::Err("Redis response was None".into())) } } else { - Err(CacheableState::Err("Redis response was None".into())) - } - } else { - Ok(CassandraResult::Void) - }; - if let Err(state) = cassandra_result { - return Err( state ); + Ok(CassandraResult::Void) + }; + if let Err(state) = cassandra_result { + return Err(state); } *cass_request = Message::from_frame(Frame::Cassandra(CassandraFrame { @@ -221,7 +233,10 @@ impl SimpleRedisCache { info!("read_from_cache called"); // build the cache query - let messages_redis_request = self.build_cache_query(&mut cassandra_messages).ok().unwrap(); + let messages_redis_request = self + .build_cache_query(&mut cassandra_messages) + .ok() + .unwrap(); // execute the cache query info!("read_from_cache calling cache_chain.process_request"); @@ -243,61 +258,80 @@ impl SimpleRedisCache { } /// clear the cache for the entry. - fn clear_table_cache(&mut self, cql_statement: &CQLStatement, table_cache_schema: &TableCacheSchema) -> Option { + fn clear_table_cache( + &mut self, + cql_statement: &CQLStatement, + table_cache_schema: &TableCacheSchema, + ) -> Option { // TODO is it possible to return the future and process in parallel? let statement = &cql_statement.statement; - if let Ok((redis_key,_hash_key)) = build_redis_key_from_cql3(statement, table_cache_schema) { + if let Ok((redis_key, _hash_key)) = build_redis_key_from_cql3(statement, table_cache_schema) + { let commands_buffer: Vec = vec![ RedisFrame::BulkString("GETDEL".into()), RedisFrame::BulkString(redis_key.into()), ]; - Some(Message::from_frame(Frame::Redis(RedisFrame::Array(commands_buffer)))) + Some(Message::from_frame(Frame::Redis(RedisFrame::Array( + commands_buffer, + )))) } else { None } } - /// calls the next transform and process the result for caching. - async fn execute_upstream_and_process_result<'a>(&mut self, message_wrapper: Wrapper<'a> + async fn execute_upstream_and_process_result<'a>( + &mut self, + message_wrapper: Wrapper<'a>, ) -> ChainResponse { let mut orig_messages = message_wrapper.messages.clone(); - let orig_cql : Option<&mut CQL> = orig_messages.iter_mut() - .filter_map( |message| { - if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Query{query,..}, .. })) = message.frame() + let orig_cql: Option<&mut CQL> = orig_messages + .iter_mut() + .filter_map(|message| { + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Query { query, .. }, + .. + })) = message.frame() { Some(query) } else { None } - - } ).next(); + }) + .next(); let result_messages = &mut message_wrapper.call_next_transform().await?; if orig_cql.is_some() { - let mut cache_messages: Vec = vec!(); - for (response, cql_statement) in result_messages.iter_mut().zip(orig_cql.unwrap().statements.iter()) { - if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Result(result), .. })) = response.frame() { + let mut cache_messages: Vec = vec![]; + for (response, cql_statement) in result_messages + .iter_mut() + .zip(orig_cql.unwrap().statements.iter()) + { + if let Some(Frame::Cassandra(CassandraFrame { + operation: CassandraOperation::Result(result), + .. + })) = response.frame() + { match is_cacheable(cql_statement) { - CacheableState::Update(table_name) | - CacheableState::Delete(table_name) => { - if let Some(table_cache_schema) = self.caching_schema.get(&table_name ) - { + CacheableState::Update(table_name) | CacheableState::Delete(table_name) => { + if let Some(table_cache_schema) = self.caching_schema.get(&table_name) { let table_schema = table_cache_schema.clone(); - if let Some(fut_message) = self.clear_table_cache(cql_statement, &table_schema) { + if let Some(fut_message) = + self.clear_table_cache(cql_statement, &table_schema) + { cache_messages.push(fut_message); } } else { - info!( "table {} is not being cached", table_name ); + info!("table {} is not being cached", table_name); } } CacheableState::Read(table_name) => { let statement = &cql_statement.statement; - if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) + if let Some(table_cache_schema) = + self.caching_schema.get(table_name.as_str()) { - if let Ok((redis_key,hash_key)) = build_redis_key_from_cql3( - statement, - table_cache_schema, - ) { + if let Ok((redis_key, hash_key)) = + build_redis_key_from_cql3(statement, table_cache_schema) + { let mut encoded: Vec = Vec::new(); let mut cursor = Cursor::new(&mut encoded); result.serialize(&mut cursor); @@ -309,27 +343,28 @@ impl SimpleRedisCache { RedisFrame::BulkString(encoded.into()), ]; - cache_messages.push( Message::from_frame(Frame::Redis(RedisFrame::Array(commands_buffer)))); - + cache_messages.push(Message::from_frame(Frame::Redis( + RedisFrame::Array(commands_buffer), + ))); } } } - CacheableState::Skip(_reason) | - CacheableState::Err(_reason) => { + CacheableState::Skip(_reason) | CacheableState::Err(_reason) => { // do nothing } } } } - if ! cache_messages.is_empty() { + if !cache_messages.is_empty() { let result = self .cache_chain .process_request( Wrapper::new_with_chain_name(cache_messages, self.cache_chain.name.clone()), "clientdetailstodo".to_string(), - ).await; + ) + .await; if result.is_err() { - warn!( "Cache error: {}", result.err().unwrap()); + warn!("Cache error: {}", result.err().unwrap()); } } } @@ -337,7 +372,6 @@ impl SimpleRedisCache { } } - /// Determines if a statement is cacheable. Cacheable statements have several common /// properties as well as operation specific properties. /// Common properties include @@ -345,13 +379,11 @@ impl SimpleRedisCache { /// * must not contain a parsing error /// * fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { - // check issues common to all cql_statements if cql_statement.has_error { - return CacheableState::Skip( "CQL statement has error".into()); + return CacheableState::Skip("CQL statement has error".into()); } - if let Some(table_name) = CQLStatement::get_table_name(&cql_statement.statement) - { + if let Some(table_name) = CQLStatement::get_table_name(&cql_statement.statement) { let has_params = CQLStatement::has_params(&cql_statement.statement); match &cql_statement.statement { @@ -363,15 +395,15 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { } else if select.where_clause.is_empty() { CacheableState::Skip("Can not cache if where clause is empty".into()) /* } else if !select.columns.is_empty() { - if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { - CacheableState::Read - } else { - CacheableState::Skip( - "Can not cache if columns other than '*' are selected".into(), - ) - } + if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { + CacheableState::Read + } else { + CacheableState::Skip( + "Can not cache if columns other than '*' are selected".into(), + ) + } - */ + */ } else { CacheableState::Read(table_name.into()) } @@ -390,16 +422,16 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { for assignment_element in &update.assignments { if assignment_element.operator.is_some() { info!( - "Clearing {} cache: {} has calculations in values", - update.table_name, assignment_element.name - ); + "Clearing {} cache: {} has calculations in values", + update.table_name, assignment_element.name + ); return CacheableState::Delete(table_name.into()); } if assignment_element.name.idx.is_some() { info!( - "Clearing {} cache: {} is an indexed columns", - update.table_name, assignment_element.name - ); + "Clearing {} cache: {} is an indexed columns", + update.table_name, assignment_element.name + ); return CacheableState::Delete(table_name.into()); } } @@ -421,57 +453,70 @@ fn build_query_redis_key_from_value_map( table_cache_schema: &TableCacheSchema, query_values: &BTreeMap>, ) -> Result { - let mut key : Vec = vec!(); - for c_name in &table_cache_schema.partition_key { - let column_name = c_name.to_lowercase(); - match query_values.get( column_name.as_str() ) { - None => { return Err( CacheableState::Skip( format!("Partition key not complete. missing segment {}", column_name )));}, - Some( relation_elements ) => { - if relation_elements.len() > 1 { - return Err( CacheableState::Skip( format!("partition key segment {} has more than one relationship", column_name))); - } - if !key.is_empty() { - key.push( b':' ) - } - key.extend(relation_elements[0].value.to_string().as_bytes() ); + let mut key: Vec = vec![]; + for c_name in &table_cache_schema.partition_key { + let column_name = c_name.to_lowercase(); + match query_values.get(column_name.as_str()) { + None => { + return Err(CacheableState::Skip(format!( + "Partition key not complete. missing segment {}", + column_name + ))); + } + Some(relation_elements) => { + if relation_elements.len() > 1 { + return Err(CacheableState::Skip(format!( + "partition key segment {} has more than one relationship", + column_name + ))); + } + if !key.is_empty() { + key.push(b':') } + key.extend(relation_elements[0].value.to_string().as_bytes()); } } - let mut skipping = false; - - for c_name in &table_cache_schema.range_key { - let column_name = c_name.to_lowercase(); - match query_values.get( column_name.as_str() ) { - None => { skipping = true; }, - Some( relation_elements ) => { - if skipping { - // we skipped an earlier column so this is an error. - return Err(CacheableState::Err( - "Columns in the middle of the range key were skipped".into(), - )); - } - if relation_elements.len() > 1 { - return Err( CacheableState::Skip( format!("partition key segment {} has more than one relationship", column_name))); - } - if !key.is_empty() { - key.push( b':' ) - } - key.extend(relation_elements[0].value.to_string().as_bytes() ); - /* - let mut my_elements = relation_elements.clone(); - - my_elements.sort(); - for relation_element in my_elements { - if !redis_key.is_empty() { - redis_key.extend(':'); - } - redis_key.extend( relation_element.value.to_string().as_bytes() ); - } -*/ + } + let mut skipping = false; + + for c_name in &table_cache_schema.range_key { + let column_name = c_name.to_lowercase(); + match query_values.get(column_name.as_str()) { + None => { + skipping = true; + } + Some(relation_elements) => { + if skipping { + // we skipped an earlier column so this is an error. + return Err(CacheableState::Err( + "Columns in the middle of the range key were skipped".into(), + )); + } + if relation_elements.len() > 1 { + return Err(CacheableState::Skip(format!( + "partition key segment {} has more than one relationship", + column_name + ))); + } + if !key.is_empty() { + key.push(b':') } + key.extend(relation_elements[0].value.to_string().as_bytes()); + /* + let mut my_elements = relation_elements.clone(); + + my_elements.sort(); + for relation_element in my_elements { + if !redis_key.is_empty() { + redis_key.extend(':'); + } + redis_key.extend( relation_element.value.to_string().as_bytes() ); + } + */ } } - Ok(BytesMut::from( key.as_slice() )) + } + Ok(BytesMut::from(key.as_slice())) } /// build the redis key for the query. @@ -480,9 +525,8 @@ fn build_query_redis_key_from_value_map( fn build_query_redis_hash_from_value_map( table_cache_schema: &TableCacheSchema, query_values: &BTreeMap>, - select : &Select + select: &Select, ) -> Result { - let mut my_values = query_values.clone(); for c_name in &table_cache_schema.partition_key { let column_name = c_name.to_lowercase(); @@ -493,33 +537,43 @@ fn build_query_redis_hash_from_value_map( my_values.remove(&column_name); } - let mut str = if select.columns.is_empty() { - String::from( "WHERE ") + let mut str = if select.columns.is_empty() { + String::from("WHERE ") } else { - let mut tmp = select.columns.iter().map(|select_element| { - match select_element { - SelectElement::Star => { SelectElement::Star } - SelectElement::Column(named) => { SelectElement::Column(Named { name: named.name.to_lowercase(), alias: match &named.alias{ - None => None, - Some(name) => Some(name.to_lowercase()), - } }) }, - SelectElement::Function(named) => {SelectElement::Function(Named { name: named.name.to_lowercase(), alias: match &named.alias{ - None => None, - Some(name) => Some(name.to_lowercase()), - } })} - } - }).join( ", "); - tmp.push_str( " WHERE "); + let mut tmp = select + .columns + .iter() + .map(|select_element| match select_element { + SelectElement::Star => SelectElement::Star, + SelectElement::Column(named) => SelectElement::Column(Named { + name: named.name.to_lowercase(), + alias: named.alias.as_ref().map(|name| name.to_lowercase()), + }), + SelectElement::Function(named) => SelectElement::Function(Named { + name: named.name.to_lowercase(), + alias: named.alias.as_ref().map(|name| name.to_lowercase()), + }), + }) + .join(", "); + tmp.push_str(" WHERE "); tmp }; - str.push_str( my_values.iter_mut().sorted() - .flat_map( |(_k,v)| v.iter() ).join(" AND ").as_str()); - - Ok(BytesMut::from( str.as_str() )) + str.push_str( + my_values + .iter_mut() + .sorted() + .flat_map(|(_k, v)| v.iter()) + .join(" AND ") + .as_str(), + ); + + Ok(BytesMut::from(str.as_str())) } - -fn populate_value_map_from_where_clause(value_map : &mut BTreeMap>, where_clause : &[RelationElement]) { +fn populate_value_map_from_where_clause( + value_map: &mut BTreeMap>, + where_clause: &[RelationElement], +) { for relation_element in where_clause { let column_name = relation_element.obj.to_string().to_lowercase(); let value = value_map.get_mut(column_name.as_str()); @@ -534,37 +588,43 @@ fn populate_value_map_from_where_clause(value_map : &mut BTreeMap Result<(BytesMut,BytesMut), CacheableState> { - let mut value_map : BTreeMap> = BTreeMap::new(); +) -> Result<(BytesMut, BytesMut), CacheableState> { + let mut value_map: BTreeMap> = BTreeMap::new(); match statement { CassandraStatement::Select(select) => { - populate_value_map_from_where_clause( &mut value_map, &select.where_clause ); - Ok((build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, - build_query_redis_hash_from_value_map(table_cache_schema, &value_map, &select)?)) + populate_value_map_from_where_clause(&mut value_map, &select.where_clause); + Ok(( + build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + build_query_redis_hash_from_value_map(table_cache_schema, &value_map, select)?, + )) } CassandraStatement::Insert(insert) => { for (c_name, operand) in insert.get_value_map().into_iter() { let column_name = c_name.to_lowercase(); - let relation_element = RelationElement{ - obj: Operand::Column( column_name.clone()), + let relation_element = RelationElement { + obj: Operand::Column(column_name.clone()), oper: RelationOperator::Equal, value: operand.clone(), }; let value = value_map.get_mut(column_name.as_str()); if let Some(vec) = value { - vec.push( relation_element) + vec.push(relation_element) } else { value_map.insert(column_name, vec![relation_element]); }; } - Ok((build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, - BytesMut::new() )) + Ok(( + build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + BytesMut::new(), + )) } CassandraStatement::Update(update) => { - populate_value_map_from_where_clause( &mut value_map, &update.where_clause); - Ok((build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, - BytesMut::new() )) + populate_value_map_from_where_clause(&mut value_map, &update.where_clause); + Ok(( + build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + BytesMut::new(), + )) } _ => unreachable!( "{} should not be passed to build_redis_key_from_cql3", @@ -573,7 +633,6 @@ fn build_redis_key_from_cql3( } } - #[async_trait] impl Transform for SimpleRedisCache { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { @@ -602,31 +661,29 @@ impl Transform for SimpleRedisCache { // If there are no write queries (all queries are reads) we can read the cache if read_cache { - match self - .read_from_cache(message_wrapper.messages.clone()) - .await - { + match self.read_from_cache(message_wrapper.messages.clone()).await { Ok(cr) => return Ok(cr), Err(inner_state) => match &inner_state { CacheableState::Skip(reason) => { tracing::info!("Cache skipped: {} ", reason); - self.execute_upstream_and_process_result( message_wrapper ).await + self.execute_upstream_and_process_result(message_wrapper) + .await } CacheableState::Err(reason) => { tracing::error!("Cache failed: {} ", reason); message_wrapper.call_next_transform().await } _ => { - unreachable!("should not find read, update or delete as an error."); + unreachable!("should not find read, update or delete as an error."); } }, } } else { - self.execute_upstream_and_process_result( message_wrapper ).await + self.execute_upstream_and_process_result(message_wrapper) + .await } } - fn validate(&self) -> Vec { let mut errors = self .cache_chain @@ -645,13 +702,15 @@ impl Transform for SimpleRedisCache { #[cfg(test)] mod test { - use crate::frame::{RedisFrame, CQL}; + use crate::frame::CQL; use crate::transforms::chain::TransformChain; use crate::transforms::debug::printer::DebugPrinter; use crate::transforms::null::Null; - use crate::transforms::redis::cache::{build_query_redis_hash_from_value_map, build_redis_key_from_cql3, SimpleRedisCache, TableCacheSchema}; + use crate::transforms::redis::cache::{ + build_redis_key_from_cql3, SimpleRedisCache, TableCacheSchema, + }; use crate::transforms::{Transform, Transforms}; - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use cql3_parser::cassandra_statement::CassandraStatement; use std::collections::HashMap; @@ -674,8 +733,8 @@ mod test { .ok() .unwrap(); - assert_eq!( BytesMut::from( "1:123:965"), redis_key); - assert_eq!( BytesMut::from( "* WHERE "), hash_key ); + assert_eq!(BytesMut::from("1:123:965"), redis_key); + assert_eq!(BytesMut::from("* WHERE "), hash_key); } #[test] @@ -691,8 +750,8 @@ mod test { .ok() .unwrap(); - assert_eq!( BytesMut::from( "1"), redis_key); - assert!( hash_key.is_empty()); + assert_eq!(BytesMut::from("1"), redis_key); + assert!(hash_key.is_empty()); } #[test] @@ -708,7 +767,7 @@ mod test { .unwrap(); assert_eq!(BytesMut::from("1:'yo'"), redis_key); - assert!( hash_key.is_empty() ); + assert!(hash_key.is_empty()); } #[test] @@ -724,7 +783,7 @@ mod test { let (redis_key, hash_key) = result.ok().unwrap(); assert_eq!(BytesMut::from("1"), redis_key); - assert!( hash_key.is_empty()); + assert!(hash_key.is_empty()); } #[test] @@ -764,9 +823,8 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from( "1" ), redis_key); - assert_eq!(BytesMut::from( "* WHERE x > 123 AND x < 999"), hash_key); - + assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("* WHERE x > 123 AND x < 999"), hash_key); } #[test] @@ -783,8 +841,7 @@ mod test { .unwrap(); assert_eq!(BytesMut::from("1"), redis_key); - assert_eq!(BytesMut::from( "* WHERE x >= 123 AND x <= 999"), hash_key); - + assert_eq!(BytesMut::from("* WHERE x >= 123 AND x <= 999"), hash_key); } #[test] @@ -800,8 +857,8 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from( "1" ), redis_key); -assert_eq!( BytesMut::from( "* WHERE "), hash_key); + assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("* WHERE "), hash_key); } #[test] @@ -818,7 +875,7 @@ assert_eq!( BytesMut::from( "* WHERE "), hash_key); .unwrap(); assert_eq!(BytesMut::from("1:2"), key); - assert_eq!(BytesMut::from( "thing WHERE "), hash_key); + assert_eq!(BytesMut::from("thing WHERE "), hash_key); } #[test] @@ -835,7 +892,7 @@ assert_eq!( BytesMut::from( "* WHERE "), hash_key); .unwrap(); assert_eq!(BytesMut::from("1"), redis_key); - assert_eq!(BytesMut::from("* WHERE x >= 123"), hash_key ); + assert_eq!(BytesMut::from("* WHERE x >= 123"), hash_key); let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x <= 123"); @@ -844,8 +901,7 @@ assert_eq!( BytesMut::from( "* WHERE "), hash_key); .unwrap(); assert_eq!(BytesMut::from("1"), redis_key); - assert_eq!(BytesMut::from("* WHERE x <= 123"), hash_key ); - + assert_eq!(BytesMut::from("* WHERE x <= 123"), hash_key); } #[tokio::test] diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 5767d5df9..08b9ba1d6 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,6 +1,5 @@ use cassandra_cpp::{stmt, Cluster, Error, Session, Value, ValueType}; use ordered_float::OrderedFloat; -use test_helpers::try_wait_for_socket_to_open; use tracing::info; pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { From 09e972fc38b4f74c5687a355ad660970bd0f477c Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 20 Apr 2022 08:17:28 +0100 Subject: [PATCH 23/60] all tests pass --- Cargo.lock | 30 +++ shotover-proxy/Cargo.toml | 2 + .../cassandra-redis-cache/topology.yaml | 4 +- shotover-proxy/src/transforms/redis/cache.rs | 181 +++++++++++------- .../cassandra_int_tests/basic_driver_tests.rs | 158 ++++++++++++--- 5 files changed, 282 insertions(+), 93 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 99da3b555..114a90db7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -698,6 +698,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + [[package]] name = "enum_primitive" version = "0.1.1" @@ -1364,14 +1370,18 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "107a38013e91c04ddf31826b0d0dcc2e0d4ebedded8234cc0dc2b7bbd0c121e8" dependencies = [ + "aho-corasick", "atomic-shim", "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.11.2", + "indexmap", "metrics", "num_cpus", + "ordered-float", "parking_lot 0.11.2", "quanta", + "radix_trie", "sketches-ddsketch", ] @@ -1438,6 +1448,15 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + [[package]] name = "nix" version = "0.23.1" @@ -1967,6 +1986,16 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "941ba9d78d8e2f7ce474c015eea4d9c6d25b6a3327f9832ee29a4de27f91bbb8" +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + [[package]] name = "rand" version = "0.8.5" @@ -2531,6 +2560,7 @@ dependencies = [ "itertools", "metrics", "metrics-exporter-prometheus", + "metrics-util", "nix", "nonzero_ext", "num", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index ee36dd82c..18b4bb94f 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -89,6 +89,8 @@ test-helpers = { path = "../test-helpers" } hex-literal = "0.3.3" nix = "0.23.0" reqwest = "0.11.6" +metrics-util = "0.12.0" + [[bench]] name = "redis_benches" diff --git a/shotover-proxy/example-configs/cassandra-redis-cache/topology.yaml b/shotover-proxy/example-configs/cassandra-redis-cache/topology.yaml index 836a43052..f865b19c2 100644 --- a/shotover-proxy/example-configs/cassandra-redis-cache/topology.yaml +++ b/shotover-proxy/example-configs/cassandra-redis-cache/topology.yaml @@ -9,10 +9,10 @@ chain_config: caching_schema: test_cache_keyspace_batch_insert.test_table: partition_key: [id] - range_key: [id] + range_key: [] test_cache_keyspace_simple.test_table: partition_key: [id] - range_key: [id] + range_key: [] chain: - RedisSinkSingle: remote_address: "127.0.0.1:6379" diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index f97b9993f..2f431aa0f 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -9,7 +9,7 @@ use crate::transforms::{ }; use anyhow::Result; use async_trait::async_trait; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use cassandra_protocol::frame::Serialize; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; @@ -18,8 +18,13 @@ use cql3_parser::select::{Named, Select, SelectElement}; use itertools::Itertools; use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; +use std::fmt::{Display, Formatter}; use std::io::Cursor; -use tracing_log::log::{info, warn}; +use std::ops::Deref; +use cql3_parser::common_drop::CommonDrop; +use serde_yaml::Index; +use tracing_log::log::{trace, debug, info, warn, error}; +use metrics::{counter, Counter, register_counter}; /* Uses redis as a cache. Data is stored in Redis as a Hash. @@ -41,12 +46,27 @@ enum CacheableState { Update(String), /// string is the table name Delete(String), + /// string is the table being dropped + Drop(String), /// string is the reason for the skip Skip(String), /// string is the reason for the error Err(String), } +impl Display for CacheableState { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CacheableState::Read(name) => { write!(f, "Reading {}", name )} + CacheableState::Update(name) => { write!(f, "Updating {}", name )} + CacheableState::Delete(name) => { write!(f, "Deleting {}", name )} + CacheableState::Drop(name) => { write!(f, "Dropping {}", name)} + CacheableState::Skip(txt) => { write!(f, "Skipping due to: {}", txt )} + CacheableState::Err(txt) => { write!(f, "Error due to: {}", txt )} + } + } +} + #[derive(Deserialize, Debug, Clone)] pub struct RedisConfig { pub caching_schema: HashMap, @@ -61,10 +81,13 @@ pub struct TableCacheSchema { impl RedisConfig { pub async fn get_transform(&self, topics: &TopicHolder) -> Result { + let missed_requests = register_counter!("cache_miss"); + Ok(Transforms::RedisCache(SimpleRedisCache { cache_chain: build_chain_from_config("cache_chain".to_string(), &self.chain, topics) .await?, caching_schema: self.caching_schema.clone(), + missed_requests, })) } } @@ -73,6 +96,7 @@ impl RedisConfig { pub struct SimpleRedisCache { cache_chain: TransformChain, caching_schema: HashMap, + missed_requests: Counter, } impl SimpleRedisCache { @@ -94,16 +118,18 @@ impl SimpleRedisCache { let mut state = is_cacheable(cql_statement); if let CacheableState::Read(table_name) = &mut state { let statement = &cql_statement.statement; - info!("build_cache_query processing cacheable state"); + debug!("build_cache_query processing cacheable state"); if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) { match build_redis_key_from_cql3(statement, table_cache_schema) { Ok((redis_key, hash_key)) => { + trace!( "Redis key: {:?}", std::str::from_utf8( redis_key.deref())); + trace!( "Hash key: {:?}", std::str::from_utf8( hash_key.deref())); let commands_buffer = vec![ RedisFrame::BulkString("HGET".into()), - RedisFrame::BulkString(redis_key.into()), - RedisFrame::BulkString(hash_key.into()), + RedisFrame::BulkString(redis_key), + RedisFrame::BulkString(hash_key), ]; messages_redis_request.push(Message::from_frame( @@ -111,6 +137,7 @@ impl SimpleRedisCache { )); } Err(err_state) => { + warn!( "build_cache_query err: {}", err_state ); state = err_state; } } @@ -128,7 +155,10 @@ impl SimpleRedisCache { } match state { - CacheableState::Err(_) | CacheableState::Skip(_) => return Err(state), + CacheableState::Err(_) | CacheableState::Skip(_) => { + debug!( "build_cache_query: {}", state ); + return Err(state) + }, _ => {} } } @@ -168,18 +198,29 @@ impl SimpleRedisCache { "Cacheable Cassandra query must be only one statement".into(), )) } else if let Some(mut redis_response) = messages_redis_response_iter.next() { + match redis_response.frame() { - Some(Frame::Redis(RedisFrame::BulkString(redis_bytes))) => { - // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) - let x = redis_bytes.iter().copied().collect_vec(); - let mut cursor = Cursor::new(x.as_slice()); - let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); - if let Ok(result) = answer { - Ok(result) - } else { - Err(CacheableState::Err(answer.err().unwrap().to_string())) + Some(Frame::Redis(redis_frame)) => { + match redis_frame { + + RedisFrame::SimpleString(s) => {Err(CacheableState::Err( "Redis returned a simple string".into() ))}, + RedisFrame::Error(e) => {return Err(CacheableState::Err(e.to_string()))}, + RedisFrame::Integer(i) => {Err(CacheableState::Err( "Redis returned an int value".into() ))}, + RedisFrame::BulkString(redis_bytes) => { + // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) + let x = redis_bytes.iter().copied().collect_vec(); + let mut cursor = Cursor::new(x.as_slice()); + let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); + if let Ok(result) = answer { + Ok(result) + } else { + Err(CacheableState::Err(answer.err().unwrap().to_string())) + }}, + RedisFrame::Array(a) => {Err(CacheableState::Err( "Redis returned an array value".into() ))}, + RedisFrame::Null => {self.missed_requests.increment(1); Err(CacheableState::Skip( "No cache results".into() ))} } } + _ => Err(CacheableState::Err( "No Redis frame in Redis response".into(), )), @@ -230,35 +271,33 @@ impl SimpleRedisCache { // + if the request is a CassandraOperation::Query then we consume a single message from messages_redis_response converting it to a cassandra response // * These are the cassandra responses that we return from the function. - info!("read_from_cache called"); + debug!("read_from_cache called"); // build the cache query let messages_redis_request = self - .build_cache_query(&mut cassandra_messages) - .ok() - .unwrap(); + .build_cache_query(&mut cassandra_messages)?; // execute the cache query - info!("read_from_cache calling cache_chain.process_request"); - - match self + debug!("read_from_cache calling cache_chain.process_request"); + let messages_redis_response = self .cache_chain .process_request( Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), "clientdetailstodo".to_string(), ) - .await - { - Ok(messages_redis_response) => { - info!("read_from_cache received OK from cache_chain.process_request"); - self.unwrap_cache_response(messages_redis_response, cassandra_messages) - } - Err(e) => Err(CacheableState::Err(format!("Redis error: {}", e))), - } + .await.map_err(|e| CacheableState::Err(format!("Redis error: {}", e)))?; + + debug!("read_from_cache received OK from cache_chain.process_request"); + self.unwrap_cache_response(messages_redis_response, cassandra_messages) + } + + /// Clears the cache for the entire table + fn clear_table_cache(&self ) -> Option { + Some(Message::from_frame(Frame::Redis(RedisFrame::BulkString( "FLUSHDB".into() )))) } - /// clear the cache for the entry. - fn clear_table_cache( + /// clear the cache for the single row specified by the redis_key + fn clear_row_cache( &mut self, cql_statement: &CQLStatement, table_cache_schema: &TableCacheSchema, @@ -268,7 +307,7 @@ impl SimpleRedisCache { if let Ok((redis_key, _hash_key)) = build_redis_key_from_cql3(statement, table_cache_schema) { let commands_buffer: Vec = vec![ - RedisFrame::BulkString("GETDEL".into()), + RedisFrame::BulkString("DEL".into()), RedisFrame::BulkString(redis_key.into()), ]; Some(Message::from_frame(Frame::Redis(RedisFrame::Array( @@ -316,14 +355,18 @@ impl SimpleRedisCache { if let Some(table_cache_schema) = self.caching_schema.get(&table_name) { let table_schema = table_cache_schema.clone(); if let Some(fut_message) = - self.clear_table_cache(cql_statement, &table_schema) + self.clear_row_cache(cql_statement, &table_schema) { cache_messages.push(fut_message); } } else { - info!("table {} is not being cached", table_name); + debug!("table {} is not being cached", table_name); } } + CacheableState::Drop(table_name) => { + info!("table {} dropped", table_name); + self.clear_table_cache(); + } CacheableState::Read(table_name) => { let statement = &cql_statement.statement; if let Some(table_cache_schema) = @@ -415,20 +458,23 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { CacheableState::Update(table_name.into()) } } + CassandraStatement::DropTable(_) => { + CacheableState::Drop( table_name.into() ) + } CassandraStatement::Update(update) => { if has_params || update.if_exists { CacheableState::Delete(table_name.into()) } else { for assignment_element in &update.assignments { if assignment_element.operator.is_some() { - info!( + debug!( "Clearing {} cache: {} has calculations in values", update.table_name, assignment_element.name ); return CacheableState::Delete(table_name.into()); } if assignment_element.name.idx.is_some() { - info!( + debug!( "Clearing {} cache: {} is an indexed columns", update.table_name, assignment_element.name ); @@ -452,10 +498,11 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { fn build_query_redis_key_from_value_map( table_cache_schema: &TableCacheSchema, query_values: &BTreeMap>, -) -> Result { +) -> Result { let mut key: Vec = vec![]; for c_name in &table_cache_schema.partition_key { let column_name = c_name.to_lowercase(); + debug!("processing partition key segment: {}", column_name); match query_values.get(column_name.as_str()) { None => { return Err(CacheableState::Skip(format!( @@ -470,6 +517,8 @@ fn build_query_redis_key_from_value_map( column_name ))); } + debug!("extending key with segment {} value {}", column_name, relation_elements[0].value); + if !key.is_empty() { key.push(b':') } @@ -498,25 +547,17 @@ fn build_query_redis_key_from_value_map( column_name ))); } + debug!("extending key with segment {} value {}", column_name, relation_elements[0].value); + if !key.is_empty() { key.push(b':') } key.extend(relation_elements[0].value.to_string().as_bytes()); - /* - let mut my_elements = relation_elements.clone(); - - my_elements.sort(); - for relation_element in my_elements { - if !redis_key.is_empty() { - redis_key.extend(':'); - } - redis_key.extend( relation_element.value.to_string().as_bytes() ); - } - */ + } } } - Ok(BytesMut::from(key.as_slice())) + Ok(BytesMut::from(key.as_slice()).freeze()) } /// build the redis key for the query. @@ -526,7 +567,7 @@ fn build_query_redis_hash_from_value_map( table_cache_schema: &TableCacheSchema, query_values: &BTreeMap>, select: &Select, -) -> Result { +) -> Result { let mut my_values = query_values.clone(); for c_name in &table_cache_schema.partition_key { let column_name = c_name.to_lowercase(); @@ -567,7 +608,7 @@ fn build_query_redis_hash_from_value_map( .as_str(), ); - Ok(BytesMut::from(str.as_str())) + Ok(BytesMut::from(str.as_str()).freeze()) } fn populate_value_map_from_where_clause( @@ -588,7 +629,7 @@ fn populate_value_map_from_where_clause( fn build_redis_key_from_cql3( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, -) -> Result<(BytesMut, BytesMut), CacheableState> { +) -> Result<(Bytes, Bytes), CacheableState> { let mut value_map: BTreeMap> = BTreeMap::new(); match statement { CassandraStatement::Select(select) => { @@ -616,14 +657,14 @@ fn build_redis_key_from_cql3( } Ok(( build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, - BytesMut::new(), + Bytes::new(), )) } CassandraStatement::Update(update) => { populate_value_map_from_where_clause(&mut value_map, &update.where_clause); Ok(( build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, - BytesMut::new(), + Bytes::new(), )) } _ => unreachable!( @@ -644,7 +685,7 @@ impl Transform for SimpleRedisCache { })) = m.frame() { for cql_statement in &query.statements { - info!("cache transform processing {}", cql_statement); + debug!("cache transform processing {}", cql_statement); match cql_statement.get_query_type() { QueryType::Read => {} QueryType::Write => read_cache = false, @@ -657,7 +698,7 @@ impl Transform for SimpleRedisCache { read_cache = false; } } - info!("cache transform read_cache:{} ", read_cache); + debug!("cache transform read_cache:{} ", read_cache); // If there are no write queries (all queries are reads) we can read the cache if read_cache { @@ -665,12 +706,12 @@ impl Transform for SimpleRedisCache { Ok(cr) => return Ok(cr), Err(inner_state) => match &inner_state { CacheableState::Skip(reason) => { - tracing::info!("Cache skipped: {} ", reason); + info!("Cache skipped: {} ", reason); self.execute_upstream_and_process_result(message_wrapper) .await } CacheableState::Err(reason) => { - tracing::error!("Cache failed: {} ", reason); + error!("Cache failed: {} ", reason); message_wrapper.call_next_transform().await } _ => { @@ -710,9 +751,10 @@ mod test { build_redis_key_from_cql3, SimpleRedisCache, TableCacheSchema, }; use crate::transforms::{Transform, Transforms}; - use bytes::BytesMut; + use bytes::{Bytes, BytesMut}; use cql3_parser::cassandra_statement::CassandraStatement; use std::collections::HashMap; + use metrics::register_counter; fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); @@ -720,6 +762,11 @@ mod test { cql.statements[0].statement.clone() } + #[test] + fn test_build_keys() { + + } + #[test] fn equal_test() { let table_cache_schema = TableCacheSchema { @@ -733,8 +780,8 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1:123:965"), redis_key); - assert_eq!(BytesMut::from("* WHERE "), hash_key); + assert_eq!(Bytes::from("1:123:965"), redis_key); + assert_eq!(Bytes::from("* WHERE "), hash_key); } #[test] @@ -847,18 +894,18 @@ mod test { #[test] fn single_pk_only_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], + partition_key: vec!["id".to_string()], range_key: vec![], }; - let ast = build_query("SELECT * FROM foo WHERE z = 1"); + let ast = build_query("SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1"); let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); assert_eq!(BytesMut::from("1"), redis_key); - assert_eq!(BytesMut::from("* WHERE "), hash_key); + assert_eq!(BytesMut::from("id, x, name WHERE "), hash_key); } #[test] @@ -906,10 +953,12 @@ mod test { #[tokio::test] async fn test_validate_invalid_chain() { + let missed_requests = register_counter!("cache_miss"); let chain = TransformChain::new(vec![], "test-chain".to_string()); let transform = SimpleRedisCache { cache_chain: chain, caching_schema: HashMap::new(), + missed_requests }; assert_eq!( @@ -924,6 +973,7 @@ mod test { #[tokio::test] async fn test_validate_valid_chain() { + let missed_requests = register_counter!("cache_miss"); let chain = TransformChain::new( vec![ Transforms::DebugPrinter(DebugPrinter::new()), @@ -935,6 +985,7 @@ mod test { let transform = SimpleRedisCache { cache_chain: chain, caching_schema: HashMap::new(), + missed_requests, }; assert_eq!(transform.validate(), Vec::::new()); diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index e844afaab..ddc6d0e27 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -2,8 +2,12 @@ use crate::helpers::cassandra::{assert_query_result, run_query, ResultValue}; use crate::helpers::ShotoverManager; use cassandra_cpp::{stmt, Batch, BatchType, Error, ErrorKind, Session}; use futures::future::{join_all, try_join_all}; +use metrics::Recorder; use serial_test::serial; use test_helpers::docker_compose::DockerCompose; +use metrics::counter; +use metrics_util::debugging::DebuggingRecorder; + mod keyspace { use crate::helpers::cassandra::{ @@ -1068,13 +1072,50 @@ mod cache { use cassandra_cpp::Session; use redis::Commands; use std::collections::HashSet; + use metrics::counter; + use metrics_util::debugging::{DebugValue, Snapshotter}; + use tracing_log::log::info; + + pub fn test(cassandra_session: &Session, redis_connection: &mut redis::Connection, snapshotter : &Snapshotter) { + test_batch_insert(cassandra_session, redis_connection, snapshotter); + test_simple(cassandra_session, redis_connection, snapshotter); + } - pub fn test(cassandra_session: &Session, redis_connection: &mut redis::Connection) { - test_batch_insert(cassandra_session, redis_connection); - test_simple(cassandra_session, redis_connection); + /// gets the current miss count from the cache instrumentation. + fn get_cache_miss_value(snapshotter: &Snapshotter) -> u64 { + let mut result = 0 as u64; + for (x,_,_,v) in snapshotter.snapshot().into_vec().iter() { + if let DebugValue::Counter(vv) = v { + if x.key().name().eq( "cache_miss") { + //return *vv; + info!( "Cache value: {}", vv ); + if *vv > result { + result = *vv; + } + } + } + } + result } - fn test_batch_insert(cassandra_session: &Session, redis_connection: &mut redis::Connection) { + /// The first time a query hits the cache it should not be found, the second time it should. + /// This function verifies that case by utilizing the cache miss instrumentation. + fn double_query(snapshotter: &Snapshotter, session: &Session, query: &str, expected_rows: &[&[ResultValue]]) { + let mut before = get_cache_miss_value(snapshotter); + // first query should miss the cache + assert_query_result(session, query, expected_rows); + let mut after = get_cache_miss_value(snapshotter); + assert_eq!( before+1, get_cache_miss_value(snapshotter), "first {}", query ); + + let before = after; + assert_query_result(session, query, expected_rows); + let after = get_cache_miss_value(snapshotter); + assert_eq!( before, after, "second {}", query ); + } + + fn test_batch_insert(cassandra_session: &Session, redis_connection: &mut redis::Connection, snapshotter : &Snapshotter) { + + redis::cmd("FLUSHDB").execute(redis_connection); run_query(cassandra_session, "CREATE KEYSPACE test_cache_keyspace_batch_insert WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"); @@ -1091,7 +1132,8 @@ mod cache { APPLY BATCH;"#, ); - // TODO: SELECTS without a WHERE do not hit cache + // selects without where clauses do not hit the cache + let mut before = get_cache_miss_value(snapshotter); assert_query_result( cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table", @@ -1113,30 +1155,45 @@ mod cache { ], ], ); + assert_eq!( before, get_cache_miss_value(snapshotter)); + // query against the primary key - assert_query_result( + double_query( &snapshotter, cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table WHERE id=1", - &[], + &[ + &[ + ResultValue::Int(1), + ResultValue::Int(11), + ResultValue::Varchar("foo".into()), + ], + ], ); - // query against some other field + let before = get_cache_miss_value(snapshotter); + // queries without key are not cached assert_query_result( cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table WHERE x=11 ALLOW FILTERING", - &[], + &[&[ + ResultValue::Int(1), + ResultValue::Int(11), + ResultValue::Varchar("foo".into()), + ],], ); + assert_eq!( before, get_cache_miss_value(snapshotter)); // Insert a dummy key to ensure the keys command is working correctly, we can remove this later. redis_connection .set::<&str, i32, ()>("dummy_key", 1) .unwrap(); - let result: Vec = redis_connection.keys("*").unwrap(); - assert_eq!(result, ["dummy_key".to_string()]); + let mut result: Vec = redis_connection.keys("*").unwrap(); + result.sort(); + assert_eq!(result, [ "1".to_string(), "dummy_key".to_string(),]); } - fn test_simple(cassandra_session: &Session, redis_connection: &mut redis::Connection) { + fn test_simple(cassandra_session: &Session, redis_connection: &mut redis::Connection, snapshotter : &Snapshotter) { redis::cmd("FLUSHDB").execute(redis_connection); run_query(cassandra_session, "CREATE KEYSPACE test_cache_keyspace_simple WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"); @@ -1158,7 +1215,8 @@ mod cache { "INSERT INTO test_cache_keyspace_simple.test_table (id, x, name) VALUES (3, 13, 'baz');", ); - // TODO: SELECTS without a WHERE do not hit cache + // selects without where clauses do not hit the cache + let mut before = get_cache_miss_value(snapshotter); assert_query_result( cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_simple.test_table", @@ -1180,29 +1238,70 @@ mod cache { ], ], ); + assert_eq!(before, get_cache_miss_value(snapshotter)); // query against the primary key - assert_query_result( - cassandra_session, - "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1", - &[], + double_query(&snapshotter, + cassandra_session, + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1", + &[ + &[ + ResultValue::Int(1), + ResultValue::Int(11), + ResultValue::Varchar("foo".into()), + ], + ], + ); + + // ensure key 2 and 3 are also loaded + double_query(&snapshotter, + cassandra_session, + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=2", + &[ + &[ + ResultValue::Int(2), + ResultValue::Int(12), + ResultValue::Varchar("bar".into()), + ], + ], + ); + + double_query(&snapshotter, + cassandra_session, + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=3", + &[ + &[ + ResultValue::Int(3), + ResultValue::Int(13), + ResultValue::Varchar("baz".into()), + ], + ], ); - // query against some other field + + // query without primary key does not hit the cache + let before = get_cache_miss_value(snapshotter); assert_query_result( cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE x=11 ALLOW FILTERING", - &[], + &[ + &[ + ResultValue::Int(1), + ResultValue::Int(11), + ResultValue::Varchar("foo".into()), + ], + ], ); + assert_eq!(before, get_cache_miss_value(snapshotter)); let result: HashSet = redis_connection.keys("*").unwrap(); let expected: HashSet = ["1", "2", "3"].into_iter().map(|x| x.to_string()).collect(); assert_eq!(result, expected); - assert_sorted_set_equals(redis_connection, "1", &["1:11", "1:'foo'"]); - assert_sorted_set_equals(redis_connection, "2", &["2:12", "2:'bar'"]); - assert_sorted_set_equals(redis_connection, "3", &["3:13", "3:'baz'"]); + assert_sorted_set_equals(redis_connection, "1", &["id, x, name WHERE "]); + assert_sorted_set_equals(redis_connection, "2", &["id, x, name WHERE "]); + assert_sorted_set_equals(redis_connection, "3", &["id, x, name WHERE "]); } fn assert_sorted_set_equals( @@ -1212,8 +1311,8 @@ mod cache { ) { let expected_values: HashSet = expected_values.iter().map(|x| x.to_string()).collect(); - let values = redis_connection - .zrange::<&str, HashSet>(key, 0, -1) + let values : HashSet = redis_connection + .hkeys( key ) .unwrap(); assert_eq!(values, expected_values) } @@ -1462,10 +1561,17 @@ fn test_source_tls_and_single_tls() { #[test] #[serial] fn test_cassandra_redis_cache() { + let recorder =DebuggingRecorder::new(); + let rec = &recorder; + let snapshotter = recorder.snapshotter(); + let result = recorder.install(); + if result.is_err() { + assert!( false, "{:?}", result.err() ); + } let _compose = DockerCompose::new("example-configs/cassandra-redis-cache/docker-compose.yml"); let shotover_manager = - ShotoverManager::from_topology_file("example-configs/cassandra-redis-cache/topology.yaml"); + ShotoverManager::from_topology_file_without_observability("example-configs/cassandra-redis-cache/topology.yaml"); let mut redis_connection = shotover_manager.redis_connection(6379); let connection = shotover_manager.cassandra_connection("127.0.0.1", 9042); @@ -1474,7 +1580,7 @@ fn test_cassandra_redis_cache() { table::test(&connection); udt::test(&connection); functions::test(&connection); - cache::test(&connection, &mut redis_connection); + cache::test(&connection, &mut redis_connection, &snapshotter); prepared_statements::test(&connection); test_batch_statements(&connection); } From 3a83dfaab6e8a7a7c05eb47cc8b2d1084354d4df Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 20 Apr 2022 09:26:37 +0100 Subject: [PATCH 24/60] fixed up cargo issues --- shotover-proxy/src/frame/cassandra.rs | 4 +- shotover-proxy/src/transforms/mod.rs | 4 +- shotover-proxy/src/transforms/redis/cache.rs | 178 ++++++++++------- .../cassandra_int_tests/basic_driver_tests.rs | 182 ++++++++++-------- shotover-proxy/tests/helpers/cassandra.rs | 4 +- 5 files changed, 225 insertions(+), 147 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 828528ccf..161c3946c 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -34,7 +34,7 @@ use std::io::Cursor; use std::net::IpAddr; use std::num::NonZeroU32; use std::str::FromStr; -use tracing::info; +use tracing::debug; use uuid::Uuid; use crate::message::QueryType::PubSubMessage; @@ -926,7 +926,7 @@ impl CQL { /// the CassandraAST handles multiple queries in a string separated by semi-colons: `;` however /// CQL only stores one query so this method only returns the first one if there are multiples. pub fn parse_from_string(cql_query_str: &str) -> Self { - info!("parse_from_string: {}", cql_query_str); + debug!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); let mut vec = Vec::with_capacity(ast.statements.len()); diff --git a/shotover-proxy/src/transforms/mod.rs b/shotover-proxy/src/transforms/mod.rs index a72d5756c..bc7bb25d0 100644 --- a/shotover-proxy/src/transforms/mod.rs +++ b/shotover-proxy/src/transforms/mod.rs @@ -304,7 +304,7 @@ pub async fn build_chain_from_config( } use std::slice::IterMut; -use tracing::info; +use tracing::debug; /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the /// remaining transforms that will process the messages attached to this [`Wrapper`]. @@ -363,7 +363,7 @@ impl<'a> Wrapper<'a> { let transform_name = transform.get_name(); let chain_name = self.chain_name.clone(); - info!( + debug!( "call_next_transform calling {} {}", transform_name, chain_name ); diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 2f431aa0f..d9d33d067 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -16,15 +16,13 @@ use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::{Operand, RelationElement, RelationOperator}; use cql3_parser::select::{Named, Select, SelectElement}; use itertools::Itertools; +use metrics::{register_counter, Counter}; use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; use std::fmt::{Display, Formatter}; use std::io::Cursor; use std::ops::Deref; -use cql3_parser::common_drop::CommonDrop; -use serde_yaml::Index; -use tracing_log::log::{trace, debug, info, warn, error}; -use metrics::{counter, Counter, register_counter}; +use tracing_log::log::{debug, error, info, trace, warn}; /* Uses redis as a cache. Data is stored in Redis as a Hash. @@ -57,12 +55,24 @@ enum CacheableState { impl Display for CacheableState { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - CacheableState::Read(name) => { write!(f, "Reading {}", name )} - CacheableState::Update(name) => { write!(f, "Updating {}", name )} - CacheableState::Delete(name) => { write!(f, "Deleting {}", name )} - CacheableState::Drop(name) => { write!(f, "Dropping {}", name)} - CacheableState::Skip(txt) => { write!(f, "Skipping due to: {}", txt )} - CacheableState::Err(txt) => { write!(f, "Error due to: {}", txt )} + CacheableState::Read(name) => { + write!(f, "Reading {}", name) + } + CacheableState::Update(name) => { + write!(f, "Updating {}", name) + } + CacheableState::Delete(name) => { + write!(f, "Deleting {}", name) + } + CacheableState::Drop(name) => { + write!(f, "Dropping {}", name) + } + CacheableState::Skip(txt) => { + write!(f, "Skipping due to: {}", txt) + } + CacheableState::Err(txt) => { + write!(f, "Error due to: {}", txt) + } } } } @@ -124,8 +134,14 @@ impl SimpleRedisCache { { match build_redis_key_from_cql3(statement, table_cache_schema) { Ok((redis_key, hash_key)) => { - trace!( "Redis key: {:?}", std::str::from_utf8( redis_key.deref())); - trace!( "Hash key: {:?}", std::str::from_utf8( hash_key.deref())); + trace!( + "Redis key: {:?}", + std::str::from_utf8(redis_key.deref()) + ); + trace!( + "Hash key: {:?}", + std::str::from_utf8(hash_key.deref()) + ); let commands_buffer = vec![ RedisFrame::BulkString("HGET".into()), RedisFrame::BulkString(redis_key), @@ -137,7 +153,7 @@ impl SimpleRedisCache { )); } Err(err_state) => { - warn!( "build_cache_query err: {}", err_state ); + warn!("build_cache_query err: {}", err_state); state = err_state; } } @@ -156,9 +172,9 @@ impl SimpleRedisCache { match state { CacheableState::Err(_) | CacheableState::Skip(_) => { - debug!( "build_cache_query: {}", state ); - return Err(state) - }, + debug!("build_cache_query: {}", state); + return Err(state); + } _ => {} } } @@ -198,26 +214,39 @@ impl SimpleRedisCache { "Cacheable Cassandra query must be only one statement".into(), )) } else if let Some(mut redis_response) = messages_redis_response_iter.next() { - match redis_response.frame() { Some(Frame::Redis(redis_frame)) => { match redis_frame { - - RedisFrame::SimpleString(s) => {Err(CacheableState::Err( "Redis returned a simple string".into() ))}, - RedisFrame::Error(e) => {return Err(CacheableState::Err(e.to_string()))}, - RedisFrame::Integer(i) => {Err(CacheableState::Err( "Redis returned an int value".into() ))}, + RedisFrame::SimpleString(_) => Err(CacheableState::Err( + "Redis returned a simple string".into(), + )), + RedisFrame::Error(e) => { + return Err(CacheableState::Err(e.to_string())) + } + RedisFrame::Integer(_) => Err(CacheableState::Err( + "Redis returned an int value".into(), + )), RedisFrame::BulkString(redis_bytes) => { // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) let x = redis_bytes.iter().copied().collect_vec(); let mut cursor = Cursor::new(x.as_slice()); - let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); + let answer = + CassandraResult::from_cursor(&mut cursor, Version::V4); if let Ok(result) = answer { Ok(result) } else { - Err(CacheableState::Err(answer.err().unwrap().to_string())) - }}, - RedisFrame::Array(a) => {Err(CacheableState::Err( "Redis returned an array value".into() ))}, - RedisFrame::Null => {self.missed_requests.increment(1); Err(CacheableState::Skip( "No cache results".into() ))} + Err(CacheableState::Err( + answer.err().unwrap().to_string(), + )) + } + } + RedisFrame::Array(_) => Err(CacheableState::Err( + "Redis returned an array value".into(), + )), + RedisFrame::Null => { + self.missed_requests.increment(1); + Err(CacheableState::Skip("No cache results".into())) + } } } @@ -274,8 +303,7 @@ impl SimpleRedisCache { debug!("read_from_cache called"); // build the cache query - let messages_redis_request = self - .build_cache_query(&mut cassandra_messages)?; + let messages_redis_request = self.build_cache_query(&mut cassandra_messages)?; // execute the cache query debug!("read_from_cache calling cache_chain.process_request"); @@ -285,15 +313,19 @@ impl SimpleRedisCache { Wrapper::new_with_chain_name(messages_redis_request, self.cache_chain.name.clone()), "clientdetailstodo".to_string(), ) - .await.map_err(|e| CacheableState::Err(format!("Redis error: {}", e)))?; + .await + .map_err(|e| CacheableState::Err(format!("Redis error: {}", e)))?; debug!("read_from_cache received OK from cache_chain.process_request"); self.unwrap_cache_response(messages_redis_response, cassandra_messages) } /// Clears the cache for the entire table - fn clear_table_cache(&self ) -> Option { - Some(Message::from_frame(Frame::Redis(RedisFrame::BulkString( "FLUSHDB".into() )))) + /// TODO make this drop only the specified keys not the entire cache + fn clear_table_cache(&self) -> Option { + Some(Message::from_frame(Frame::Redis(RedisFrame::BulkString( + "FLUSHDB".into(), + )))) } /// clear the cache for the single row specified by the redis_key @@ -308,7 +340,7 @@ impl SimpleRedisCache { { let commands_buffer: Vec = vec![ RedisFrame::BulkString("DEL".into()), - RedisFrame::BulkString(redis_key.into()), + RedisFrame::BulkString(redis_key), ]; Some(Message::from_frame(Frame::Redis(RedisFrame::Array( commands_buffer, @@ -381,8 +413,8 @@ impl SimpleRedisCache { let commands_buffer: Vec = vec![ RedisFrame::BulkString("HSET".into()), - RedisFrame::BulkString(redis_key.into()), - RedisFrame::BulkString(hash_key.into()), + RedisFrame::BulkString(redis_key), + RedisFrame::BulkString(hash_key), RedisFrame::BulkString(encoded.into()), ]; @@ -458,9 +490,7 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { CacheableState::Update(table_name.into()) } } - CassandraStatement::DropTable(_) => { - CacheableState::Drop( table_name.into() ) - } + CassandraStatement::DropTable(_) => CacheableState::Drop(table_name.into()), CassandraStatement::Update(update) => { if has_params || update.if_exists { CacheableState::Delete(table_name.into()) @@ -498,8 +528,10 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { fn build_query_redis_key_from_value_map( table_cache_schema: &TableCacheSchema, query_values: &BTreeMap>, + table_name: &str, ) -> Result { let mut key: Vec = vec![]; + key.extend(table_name.as_bytes()); for c_name in &table_cache_schema.partition_key { let column_name = c_name.to_lowercase(); debug!("processing partition key segment: {}", column_name); @@ -517,11 +549,11 @@ fn build_query_redis_key_from_value_map( column_name ))); } - debug!("extending key with segment {} value {}", column_name, relation_elements[0].value); - - if !key.is_empty() { - key.push(b':') - } + debug!( + "extending key with segment {} value {}", + column_name, relation_elements[0].value + ); + key.push(b':'); key.extend(relation_elements[0].value.to_string().as_bytes()); } } @@ -547,13 +579,13 @@ fn build_query_redis_key_from_value_map( column_name ))); } - debug!("extending key with segment {} value {}", column_name, relation_elements[0].value); + debug!( + "extending key with segment {} value {}", + column_name, relation_elements[0].value + ); - if !key.is_empty() { - key.push(b':') - } + key.push(b':'); key.extend(relation_elements[0].value.to_string().as_bytes()); - } } } @@ -635,7 +667,11 @@ fn build_redis_key_from_cql3( CassandraStatement::Select(select) => { populate_value_map_from_where_clause(&mut value_map, &select.where_clause); Ok(( - build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + build_query_redis_key_from_value_map( + table_cache_schema, + &value_map, + &select.table_name, + )?, build_query_redis_hash_from_value_map(table_cache_schema, &value_map, select)?, )) } @@ -656,14 +692,22 @@ fn build_redis_key_from_cql3( }; } Ok(( - build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + build_query_redis_key_from_value_map( + table_cache_schema, + &value_map, + &insert.table_name, + )?, Bytes::new(), )) } CassandraStatement::Update(update) => { populate_value_map_from_where_clause(&mut value_map, &update.where_clause); Ok(( - build_query_redis_key_from_value_map(table_cache_schema, &value_map)?, + build_query_redis_key_from_value_map( + table_cache_schema, + &value_map, + &update.table_name, + )?, Bytes::new(), )) } @@ -753,8 +797,8 @@ mod test { use crate::transforms::{Transform, Transforms}; use bytes::{Bytes, BytesMut}; use cql3_parser::cassandra_statement::CassandraStatement; - use std::collections::HashMap; use metrics::register_counter; + use std::collections::HashMap; fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); @@ -763,9 +807,7 @@ mod test { } #[test] - fn test_build_keys() { - - } + fn test_build_keys() {} #[test] fn equal_test() { @@ -780,7 +822,7 @@ mod test { .ok() .unwrap(); - assert_eq!(Bytes::from("1:123:965"), redis_key); + assert_eq!(Bytes::from("foo:1:123:965"), redis_key); assert_eq!(Bytes::from("* WHERE "), hash_key); } @@ -797,7 +839,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("foo:1"), redis_key); assert!(hash_key.is_empty()); } @@ -813,7 +855,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1:'yo'"), redis_key); + assert_eq!(BytesMut::from("foo:1:'yo'"), redis_key); assert!(hash_key.is_empty()); } @@ -829,7 +871,7 @@ mod test { let result = build_redis_key_from_cql3(&ast, &table_cache_schema); let (redis_key, hash_key) = result.ok().unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("foo:1"), redis_key); assert!(hash_key.is_empty()); } @@ -870,7 +912,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("foo:1"), redis_key); assert_eq!(BytesMut::from("* WHERE x > 123 AND x < 999"), hash_key); } @@ -887,7 +929,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("foo:1"), redis_key); assert_eq!(BytesMut::from("* WHERE x >= 123 AND x <= 999"), hash_key); } @@ -898,13 +940,17 @@ mod test { range_key: vec![], }; - let ast = build_query("SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1"); + let ast = + build_query("SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1"); let (redis_key, hash_key) = build_redis_key_from_cql3(&ast, &table_cache_schema) .ok() .unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!( + BytesMut::from("test_cache_keyspace_simple.test_table:1"), + redis_key + ); assert_eq!(BytesMut::from("id, x, name WHERE "), hash_key); } @@ -921,7 +967,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1:2"), key); + assert_eq!(BytesMut::from("foo:1:2"), key); assert_eq!(BytesMut::from("thing WHERE "), hash_key); } @@ -938,7 +984,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("foo:1"), redis_key); assert_eq!(BytesMut::from("* WHERE x >= 123"), hash_key); let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x <= 123"); @@ -947,7 +993,7 @@ mod test { .ok() .unwrap(); - assert_eq!(BytesMut::from("1"), redis_key); + assert_eq!(BytesMut::from("foo:1"), redis_key); assert_eq!(BytesMut::from("* WHERE x <= 123"), hash_key); } @@ -958,7 +1004,7 @@ mod test { let transform = SimpleRedisCache { cache_chain: chain, caching_schema: HashMap::new(), - missed_requests + missed_requests, }; assert_eq!( diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index ddc6d0e27..827035fc3 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -2,12 +2,9 @@ use crate::helpers::cassandra::{assert_query_result, run_query, ResultValue}; use crate::helpers::ShotoverManager; use cassandra_cpp::{stmt, Batch, BatchType, Error, ErrorKind, Session}; use futures::future::{join_all, try_join_all}; -use metrics::Recorder; +use metrics_util::debugging::DebuggingRecorder; use serial_test::serial; use test_helpers::docker_compose::DockerCompose; -use metrics::counter; -use metrics_util::debugging::DebuggingRecorder; - mod keyspace { use crate::helpers::cassandra::{ @@ -1070,25 +1067,28 @@ mod prepared_statements { mod cache { use crate::helpers::cassandra::{assert_query_result, run_query, ResultValue}; use cassandra_cpp::Session; + use metrics_util::debugging::{DebugValue, Snapshotter}; use redis::Commands; use std::collections::HashSet; - use metrics::counter; - use metrics_util::debugging::{DebugValue, Snapshotter}; use tracing_log::log::info; - pub fn test(cassandra_session: &Session, redis_connection: &mut redis::Connection, snapshotter : &Snapshotter) { + pub fn test( + cassandra_session: &Session, + redis_connection: &mut redis::Connection, + snapshotter: &Snapshotter, + ) { test_batch_insert(cassandra_session, redis_connection, snapshotter); test_simple(cassandra_session, redis_connection, snapshotter); } /// gets the current miss count from the cache instrumentation. fn get_cache_miss_value(snapshotter: &Snapshotter) -> u64 { - let mut result = 0 as u64; - for (x,_,_,v) in snapshotter.snapshot().into_vec().iter() { + let mut result = 0_u64; + for (x, _, _, v) in snapshotter.snapshot().into_vec().iter() { if let DebugValue::Counter(vv) = v { - if x.key().name().eq( "cache_miss") { + if x.key().name().eq("cache_miss") { //return *vv; - info!( "Cache value: {}", vv ); + info!("Cache value: {}", vv); if *vv > result { result = *vv; } @@ -1100,22 +1100,34 @@ mod cache { /// The first time a query hits the cache it should not be found, the second time it should. /// This function verifies that case by utilizing the cache miss instrumentation. - fn double_query(snapshotter: &Snapshotter, session: &Session, query: &str, expected_rows: &[&[ResultValue]]) { - let mut before = get_cache_miss_value(snapshotter); + fn double_query( + snapshotter: &Snapshotter, + session: &Session, + query: &str, + expected_rows: &[&[ResultValue]], + ) { + let before = get_cache_miss_value(snapshotter); // first query should miss the cache assert_query_result(session, query, expected_rows); - let mut after = get_cache_miss_value(snapshotter); - assert_eq!( before+1, get_cache_miss_value(snapshotter), "first {}", query ); + let after = get_cache_miss_value(snapshotter); + assert_eq!( + before + 1, + get_cache_miss_value(snapshotter), + "first {}", + query + ); let before = after; assert_query_result(session, query, expected_rows); let after = get_cache_miss_value(snapshotter); - assert_eq!( before, after, "second {}", query ); + assert_eq!(before, after, "second {}", query); } - fn test_batch_insert(cassandra_session: &Session, redis_connection: &mut redis::Connection, snapshotter : &Snapshotter) { - - + fn test_batch_insert( + cassandra_session: &Session, + redis_connection: &mut redis::Connection, + snapshotter: &Snapshotter, + ) { redis::cmd("FLUSHDB").execute(redis_connection); run_query(cassandra_session, "CREATE KEYSPACE test_cache_keyspace_batch_insert WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"); @@ -1133,7 +1145,7 @@ mod cache { ); // selects without where clauses do not hit the cache - let mut before = get_cache_miss_value(snapshotter); + let before = get_cache_miss_value(snapshotter); assert_query_result( cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table", @@ -1155,20 +1167,18 @@ mod cache { ], ], ); - assert_eq!( before, get_cache_miss_value(snapshotter)); - + assert_eq!(before, get_cache_miss_value(snapshotter)); // query against the primary key - double_query( &snapshotter, + double_query( + snapshotter, cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_batch_insert.test_table WHERE id=1", - &[ - &[ + &[&[ ResultValue::Int(1), ResultValue::Int(11), ResultValue::Varchar("foo".into()), - ], - ], + ]], ); let before = get_cache_miss_value(snapshotter); @@ -1182,7 +1192,7 @@ mod cache { ResultValue::Varchar("foo".into()), ],], ); - assert_eq!( before, get_cache_miss_value(snapshotter)); + assert_eq!(before, get_cache_miss_value(snapshotter)); // Insert a dummy key to ensure the keys command is working correctly, we can remove this later. redis_connection @@ -1190,10 +1200,20 @@ mod cache { .unwrap(); let mut result: Vec = redis_connection.keys("*").unwrap(); result.sort(); - assert_eq!(result, [ "1".to_string(), "dummy_key".to_string(),]); + assert_eq!( + result, + [ + "dummy_key".to_string(), + "test_cache_keyspace_batch_insert.test_table:1".to_string(), + ] + ); } - fn test_simple(cassandra_session: &Session, redis_connection: &mut redis::Connection, snapshotter : &Snapshotter) { + fn test_simple( + cassandra_session: &Session, + redis_connection: &mut redis::Connection, + snapshotter: &Snapshotter, + ) { redis::cmd("FLUSHDB").execute(redis_connection); run_query(cassandra_session, "CREATE KEYSPACE test_cache_keyspace_simple WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"); @@ -1216,7 +1236,7 @@ mod cache { ); // selects without where clauses do not hit the cache - let mut before = get_cache_miss_value(snapshotter); + let before = get_cache_miss_value(snapshotter); assert_query_result( cassandra_session, "SELECT id, x, name FROM test_cache_keyspace_simple.test_table", @@ -1241,44 +1261,40 @@ mod cache { assert_eq!(before, get_cache_miss_value(snapshotter)); // query against the primary key - double_query(&snapshotter, - cassandra_session, - "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1", - &[ - &[ - ResultValue::Int(1), - ResultValue::Int(11), - ResultValue::Varchar("foo".into()), - ], - ], + double_query( + snapshotter, + cassandra_session, + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=1", + &[&[ + ResultValue::Int(1), + ResultValue::Int(11), + ResultValue::Varchar("foo".into()), + ]], ); // ensure key 2 and 3 are also loaded - double_query(&snapshotter, - cassandra_session, - "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=2", - &[ - &[ - ResultValue::Int(2), - ResultValue::Int(12), - ResultValue::Varchar("bar".into()), - ], - ], + double_query( + snapshotter, + cassandra_session, + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=2", + &[&[ + ResultValue::Int(2), + ResultValue::Int(12), + ResultValue::Varchar("bar".into()), + ]], ); - double_query(&snapshotter, - cassandra_session, - "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=3", - &[ - &[ - ResultValue::Int(3), - ResultValue::Int(13), - ResultValue::Varchar("baz".into()), - ], - ], + double_query( + snapshotter, + cassandra_session, + "SELECT id, x, name FROM test_cache_keyspace_simple.test_table WHERE id=3", + &[&[ + ResultValue::Int(3), + ResultValue::Int(13), + ResultValue::Varchar("baz".into()), + ]], ); - // query without primary key does not hit the cache let before = get_cache_miss_value(snapshotter); assert_query_result( @@ -1295,13 +1311,31 @@ mod cache { assert_eq!(before, get_cache_miss_value(snapshotter)); let result: HashSet = redis_connection.keys("*").unwrap(); - let expected: HashSet = - ["1", "2", "3"].into_iter().map(|x| x.to_string()).collect(); + let expected: HashSet = [ + "test_cache_keyspace_simple.test_table:1", + "test_cache_keyspace_simple.test_table:2", + "test_cache_keyspace_simple.test_table:3", + ] + .into_iter() + .map(|x| x.to_string()) + .collect(); assert_eq!(result, expected); - assert_sorted_set_equals(redis_connection, "1", &["id, x, name WHERE "]); - assert_sorted_set_equals(redis_connection, "2", &["id, x, name WHERE "]); - assert_sorted_set_equals(redis_connection, "3", &["id, x, name WHERE "]); + assert_sorted_set_equals( + redis_connection, + "test_cache_keyspace_simple.test_table:1", + &["id, x, name WHERE "], + ); + assert_sorted_set_equals( + redis_connection, + "test_cache_keyspace_simple.test_table:2", + &["id, x, name WHERE "], + ); + assert_sorted_set_equals( + redis_connection, + "test_cache_keyspace_simple.test_table:3", + &["id, x, name WHERE "], + ); } fn assert_sorted_set_equals( @@ -1311,9 +1345,7 @@ mod cache { ) { let expected_values: HashSet = expected_values.iter().map(|x| x.to_string()).collect(); - let values : HashSet = redis_connection - .hkeys( key ) - .unwrap(); + let values: HashSet = redis_connection.hkeys(key).unwrap(); assert_eq!(values, expected_values) } } @@ -1561,17 +1593,17 @@ fn test_source_tls_and_single_tls() { #[test] #[serial] fn test_cassandra_redis_cache() { - let recorder =DebuggingRecorder::new(); - let rec = &recorder; + let recorder = DebuggingRecorder::new(); let snapshotter = recorder.snapshotter(); let result = recorder.install(); if result.is_err() { - assert!( false, "{:?}", result.err() ); + panic!("{:?}", result.err()); } let _compose = DockerCompose::new("example-configs/cassandra-redis-cache/docker-compose.yml"); - let shotover_manager = - ShotoverManager::from_topology_file_without_observability("example-configs/cassandra-redis-cache/topology.yaml"); + let shotover_manager = ShotoverManager::from_topology_file_without_observability( + "example-configs/cassandra-redis-cache/topology.yaml", + ); let mut redis_connection = shotover_manager.redis_connection(6379); let connection = shotover_manager.cassandra_connection("127.0.0.1", 9042); diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 08b9ba1d6..800237ee1 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -1,6 +1,6 @@ use cassandra_cpp::{stmt, Cluster, Error, Session, Value, ValueType}; use ordered_float::OrderedFloat; -use tracing::info; +use tracing_log::log::debug; pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { for contact_point in contact_points.split(',') { @@ -108,7 +108,7 @@ impl ResultValue { #[allow(unused)] pub fn execute_query(session: &Session, query: &str) -> Vec> { let statement = stmt!(query); - info!("executing query: {}", query); + debug!("executing query: {}", query); match session.execute(&statement).wait() { Ok(result) => result .into_iter() From d9d2dad7c18f05f4b33615779765a0b5b3fcac8f Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 20 Apr 2022 12:07:14 +0100 Subject: [PATCH 25/60] added cache documentation --- shotover-proxy/src/transforms/redis/cache.md | 123 +++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 shotover-proxy/src/transforms/redis/cache.md diff --git a/shotover-proxy/src/transforms/redis/cache.md b/shotover-proxy/src/transforms/redis/cache.md new file mode 100644 index 000000000..88b85265f --- /dev/null +++ b/shotover-proxy/src/transforms/redis/cache.md @@ -0,0 +1,123 @@ +# Cache Design Documentation + +## Overview +The redis cache design is intended to cache Cassandra queries into +a Redis based cache. It intercepts calls to Cassandra determines +if it can respond and if so intercepts the entire set of Messages. +If not, the messages are sent on for processing and any cacheable +results are processed on the return. + +## Cassandra mapping + +There are 38 Cassandra query commands of which we are concerned with +6: + + * Select + * Insert + * Update + * Delete + * Truncate + * Drop Table + +We initially process Select statements to see if we can respond to them. +We also process Select statements on their return from Cassandra to +populate the cache. + +Insert, Update, Delete methods clear the cache for specific rows or +potentially cause the entire table to be removed from the cache. + +Truncate and Drop Table calls cause all the table entries to be +removed from the cache. + +## Redis + +### Cache configuration + +The cache configuration is specified in the topology yaml file. +entries in the `caching_schema` define which tables are caching +candidates and defines the partition and range key. Combined the +partition and range key define the primary key. The partition key +should match the Cassandra partition key, and the range key should +match the Cassandra clustering key. This is not checked during +operation and is not strictly required, however, unexpected results +may occur if they do not match. + +In the following extract, `test_cache_keyspace_batch_insert.test_table` +is a fully qualified table name. That table has a single partition +key segment: the `id` column from the table. There is no range key +defined. + +```yaml + caching_schema: + test_cache_keyspace_batch_insert.test_table: + partition_key: [id] + range_key: [] + test_cache_keyspace_simple.test_table: + partition_key: [id] + range_key: [] +``` + +### Cache structure + +The redis cache is a single Redis source. All keys into the redis +source are defined by the fully qualified table name, the primary +key segments, and the range segments. + +Each key identifies a Hashset. The keys within the hash set +comprise the base column names in the select (not tha aliases) +followed by " WHERE " and then the filtering statements other than +the Cassandra key fields. + +for example `SELECT a, b, c as g FROM keyspace1.table2 WHERE e='foo' a[2]=3` +on where table2 is defined with a primary key of `e` will yield the redis key +`keyspace1.table2:'foo'` and a hash key of `a b c WHERE a[2]=3`. The entry +in the redis cache is a serialized version of the row metadata as well as the +row values as returned from the Cassandra call. + +### Instrumentation + +The cache uses the metrics crate to provide a counter called "cache_miss" that records +every attempt to read the cache where the cache did not have the requested value. It +does not count any statement that was rejected because it need not meet the requirements +for a cacheable statement. + +## Cacheable statements + +The cache system defines a CacheableState. The possible cacheable states are + * Read( table ) - Indicates that the statement reads a table and so is may be served from the cache or alternatively the cache shoud be updated. + * Update( table ) - Indicates that the statement performs an update on the table. Only effected rows should be removed. + * Delete( table ) - Indicates that the statement deletes a row or part of a row from the table. Only effected rows should be removed. + * Drop( table ) - Indicates tha the table should be dropped from the cache. + * Skip( reason ) - Indicates that the cache should not be queried/updated. The reason is a human readable statement of why that may appear in the logs. + * Err( reason ) - Indicates that an error has occurred in the caching system. The reason is a human readable statement that will appear in the logs. + +### Table of CacheableState results + +| Cassandra
Query | Read | Update | Delete | Skip | Error | +|---------------------|----------|---------|--------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------|-------| +| Select | default | | * has parameters | * allows filtering
* no `where` clause | | +| Insert | | default | * has parameters
* has `if not exists` clause | | | +| Update | | default | * has parameters
* has `if exists` clause
* if there are calculations in the assignment
* if the assigned colum uses an index. | | | +| Delete | | | not implemented | default | | +| Truncate | | | not implemented | default | | +| Drop Table | | | default - delete all | | | +| All others | | | | default | | + + +Once a statement is determined to be cacheable it is processed accordingly. During +processing the state may be reset. Specifically it may be reset to Skip or Error. + +## Process Flow + +The cache system sits on both the outbound and return data flows. On the outbound flow, +toward the data source, the cache determines if the message wrapper contains only +cacheable Cassandra statements and if so attempts to answer from the cache. If during +the read from the cache any statements fail to be answered the entire message wrapper +is forwarded down the chain and the results are processed. + +If the cache can not answer _ALL_ the original Cassandra queries the answers +from upstream are processed. If the upstream commands were successfull, the +results may modify the cache. Any command with a cacheable state of +Read, Update or Delete is processed again on return and the cache updated +appropriately. + From 1f891feb081a8a4d5884d3ca785ea072bf595bbd Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 20 Apr 2022 12:59:35 +0100 Subject: [PATCH 26/60] update cql3 dependency --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 114a90db7..c3f8db081 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#34cebfc9972c8aab0ce2d31c894c9d338ccdbb06" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#d769cfcef25114d19c88f369536215763a3bd471" dependencies = [ "bigdecimal", "bytes", From 7309cf223e434fd8a9127439efa7b91736fbf751 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 20 Apr 2022 13:20:19 +0100 Subject: [PATCH 27/60] fixed cargo issues --- shotover-proxy/src/frame/cassandra.rs | 14 +++++++------- shotover-proxy/tests/helpers/cassandra.rs | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 161c3946c..ed31769bc 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -945,12 +945,12 @@ impl CQL { } pub trait ToCassandraType { - fn from_string_value(&self, value: &str) -> Option; + fn from_string_value(value: &str) -> Option; fn as_cassandra_type(&self) -> Option; } impl ToCassandraType for Operand { - fn from_string_value(&self, value: &str) -> Option { + fn from_string_value(value: &str) -> Option { // check for string types if value.starts_with('\'') || value.starts_with("$$") { Some(CassandraType::Varchar(value.to_string())) @@ -975,14 +975,14 @@ impl ToCassandraType for Operand { fn as_cassandra_type(&self) -> Option { match self { - Operand::Const(value) => self.from_string_value(value), + Operand::Const(value) => Operand::from_string_value(value), Operand::Map(values) => Some(CassandraType::Map( values .iter() .map(|(key, value)| { ( - self.from_string_value(key).unwrap(), - self.from_string_value(value).unwrap(), + Operand::from_string_value(key).unwrap(), + Operand::from_string_value(value).unwrap(), ) }) .collect(), @@ -990,13 +990,13 @@ impl ToCassandraType for Operand { Operand::Set(values) => Some(CassandraType::Set( values .iter() - .filter_map(|value| self.from_string_value(value)) + .filter_map(|value| Operand::from_string_value(value)) .collect(), )), Operand::List(values) => Some(CassandraType::List( values .iter() - .filter_map(|value| self.from_string_value(value)) + .filter_map(|value| Operand::from_string_value(value)) .collect(), )), Operand::Tuple(values) => Some(CassandraType::Tuple( diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index 800237ee1..ae981d6d8 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -13,7 +13,7 @@ pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { cluster.set_load_balance_round_robin(); let result = cluster.connect(); if let Some(err) = &result.as_ref().err() { - assert!(false, "{}", err); + panic!( "{}", err); } result.unwrap() } From 1ea5b6411e4c15c0d866dfa34308ce249c751f1e Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 20 Apr 2022 13:32:27 +0100 Subject: [PATCH 28/60] fixed cargo issues --- shotover-proxy/tests/helpers/cassandra.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index ae981d6d8..f73746ee2 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -13,7 +13,7 @@ pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { cluster.set_load_balance_round_robin(); let result = cluster.connect(); if let Some(err) = &result.as_ref().err() { - panic!( "{}", err); + panic!("{}", err); } result.unwrap() } From f961aae82db91dc75cbaac8a93bbc626f0b59794 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 21 Apr 2022 07:11:56 +0100 Subject: [PATCH 29/60] updated cql parser --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index c3f8db081..536fa306b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -438,7 +438,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#d769cfcef25114d19c88f369536215763a3bd471" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#101e86e0c9b98241b4e216a0639ee269ec994248" dependencies = [ "bigdecimal", "bytes", From e1bdebc0d998b758cf050d00356f5d114613fe08 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 21 Apr 2022 08:24:32 +0100 Subject: [PATCH 30/60] modified test to verify decrypted data --- .../tests/cassandra_int_tests/basic_driver_tests.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index 827035fc3..0a8950100 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1385,9 +1385,10 @@ mod protect { "SELECT pk, cluster, col1, col2, col3 FROM test_protect_keyspace.test_table", ); if let ResultValue::Varchar(value) = &result[0][2] { - assert!(value.starts_with("{\"Ciphertext")); + assert_eq!( "I am gonna get encrypted!!", value); + //assert!(value.starts_with("{\"Ciphertext")); } else { - panic!("expectected 3rd column to be ResultValue::Varchar in {result:?}"); + panic!("expected 3rd column to be ResultValue::Varchar in {result:?}"); } // assert that data is encrypted on cassandra side From 1f4ae01b7e43f1c999400eadf3b2d8e9875f95aa Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 22 Apr 2022 06:48:48 +0100 Subject: [PATCH 31/60] changed to use FQName --- Cargo.lock | 68 ++--- shotover-proxy/Cargo.toml | 3 +- shotover-proxy/src/frame/cassandra.rs | 10 +- shotover-proxy/src/message/mod.rs | 2 +- .../src/transforms/cassandra/peers_rewrite.rs | 2 +- shotover-proxy/src/transforms/protect/mod.rs | 237 +++++++++++------- shotover-proxy/src/transforms/redis/cache.rs | 6 +- .../cassandra_int_tests/basic_driver_tests.rs | 2 +- 8 files changed, 197 insertions(+), 133 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 536fa306b..83f4af890 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,9 +126,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.64" +version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f" +checksum = "11a17d453482a265fd5f8479f2a3f405566e6ca627837aaddb85af8b1ab8ef61" dependencies = [ "addr2line", "cc", @@ -290,8 +290,8 @@ checksum = "8234d29d30873ab5a41e3557b8515d3ecbaefb1ea5be579425b3b0074b6d0e40" [[package]] name = "cassandra-protocol" -version = "1.1.0" -source = "git+https://github.com/krojew/cdrs-tokio#61afd7f9fa897635a78a15a4e3e89864d16101f6" +version = "1.1.1" +source = "git+https://github.com/krojew/cdrs-tokio#c5e0cf444262ec811981af9486f1a269305ec3a8" dependencies = [ "arrayref", "bitflags", @@ -356,16 +356,16 @@ dependencies = [ [[package]] name = "clap" -version = "3.1.8" +version = "3.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c47df61d9e16dc010b55dba1952a57d8c215dbb533fd13cdd13369aac73b1c" +checksum = "3124f3f75ce09e22d1410043e1e24f2ecc44fad3afe4f08408f1f7663d68da2b" dependencies = [ "atty", "bitflags", "clap_derive", + "clap_lex", "indexmap", "lazy_static", - "os_str_bytes", "strsim", "termcolor", "textwrap 0.15.0", @@ -384,6 +384,15 @@ dependencies = [ "syn", ] +[[package]] +name = "clap_lex" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "189ddd3b5d32a70b35e7686054371742a937b0d99128e76dde6340210e966669" +dependencies = [ + "os_str_bytes", +] + [[package]] name = "combine" version = "4.6.3" @@ -438,7 +447,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#101e86e0c9b98241b4e216a0639ee269ec994248" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=create_FQName#e6d5b9d352d0212ebbeccd427422efca1c792ea6" dependencies = [ "bigdecimal", "bytes", @@ -1154,9 +1163,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e70ee094dc02fd9c13fdad4940090f22dbd6ac7c9e7094a46cf0232a50bc7c" +checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" [[package]] name = "itertools" @@ -1209,9 +1218,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.122" +version = "0.2.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259" +checksum = "21a41fed9d98f27ab1c6d161da622a4fa35e8a54a8adc24bbf3ddd0ef70b0e50" [[package]] name = "libloading" @@ -1399,12 +1408,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.4.4" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +checksum = "d2b29bd4bc3f33391105ebee3589c19197c4271e3e5a9ec9bfe8127eeff8f082" dependencies = [ "adler", - "autocfg", ] [[package]] @@ -1646,9 +1654,9 @@ dependencies = [ [[package]] name = "object" -version = "0.27.1" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9" +checksum = "40bec70ba014595f99f7aa110b84331ffe1ee9aece7fe6f387cc7e3ecda4d456" dependencies = [ "memchr", ] @@ -1729,9 +1737,6 @@ name = "os_str_bytes" version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e22443d1643a904602595ba1cd8f7d896afe56d26712531c5ff73a15b2fbf64" -dependencies = [ - "memchr", -] [[package]] name = "parking_lot" @@ -1783,9 +1788,9 @@ dependencies = [ [[package]] name = "pcap" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "742d671d505d54e83924cc200e08442111f0ace8ebee3b9dc31cd9710cb301cb" +checksum = "42d1868a121e5f1d78134be7148b778d5352f28b7be30e7f993c3439671f0190" dependencies = [ "errno", "libc", @@ -2047,9 +2052,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.5.1" +version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" +checksum = "fd249e82c21598a9a426a4e00dd7adc1d640b22445ec8545feef801d1a74c221" dependencies = [ "autocfg", "crossbeam-deque", @@ -2059,14 +2064,13 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.9.1" +version = "1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" +checksum = "9f51245e1e62e1f1629cbfec37b5793bbabcaeb90f30e94d2ba03564687353e4" dependencies = [ "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "lazy_static", "num_cpus", ] @@ -2545,7 +2549,7 @@ dependencies = [ "cached", "cassandra-cpp", "cassandra-protocol", - "clap 3.1.8", + "clap 3.1.10", "cql3_parser", "crc16", "criterion", @@ -3005,9 +3009,9 @@ checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6" [[package]] name = "tracing" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80b9fa4360528139bc96100c160b7ae879f5567f49f1782b0b02035b0358ebf3" +checksum = "5d0ecdcb44a79f0fe9844f0c4f33a342cbcbb5117de8001e6ba0dc2351327d09" dependencies = [ "cfg-if", "pin-project-lite", @@ -3039,9 +3043,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90442985ee2f57c9e1b548ee72ae842f4a9a20e3f417cc38dbc5dc684d9bb4ee" +checksum = "f54c8ca710e81886d498c2fd3331b56c93aa248d49de2222ad2742247c60072f" dependencies = [ "lazy_static", "valuable", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 18b4bb94f..7bab5492a 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] +default = ["alpha-transforms"] # Include WIP alpha transforms in the public API alpha-transforms = [] @@ -42,7 +43,7 @@ anyhow = "1.0.31" # Parsers sqlparser = "0.16" -cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch = "main" } +cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch = "create_FQName" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index ed31769bc..fd4fd174d 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -22,7 +22,7 @@ use cassandra_protocol::types::value::Value; use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{Operand, RelationElement}; +use cql3_parser::common::{Operand, RelationElement, FQName}; use cql3_parser::insert::InsertValues; use cql3_parser::update::AssignmentOperator; use itertools::Itertools; @@ -285,13 +285,13 @@ impl CassandraFrame { } /// returns a list of table names from the CassandraOperation - pub fn get_table_names(&self) -> Vec { + pub fn get_table_names(&self) -> Vec<&FQName> { let mut result = vec![]; match &self.operation { CassandraOperation::Query { query: cql, .. } => { for cql_statement in &cql.statements { if let Some(name) = CQLStatement::get_table_name(&cql_statement.statement) { - result.push(name.into()); + result.push(name); } } } @@ -302,7 +302,7 @@ impl CassandraFrame { if let Some(name) = CQLStatement::get_table_name(&cql_statement.statement) { - result.push(name.into()); + result.push(name); } } } @@ -587,7 +587,7 @@ impl CQLStatement { } /// returns the table name specified in the command if one is present. - pub fn get_table_name(statement: &CassandraStatement) -> Option<&String> { + pub fn get_table_name(statement: &CassandraStatement) -> Option<&FQName> { match statement { CassandraStatement::AlterTable(t) => Some(&t.name), CassandraStatement::CreateIndex(i) => Some(&i.table), diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 99e80da16..d9a29d792 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -177,7 +177,7 @@ impl Message { /// None if the statements do not contain table names. pub fn get_table_names(&mut self) -> Vec { match self.frame() { - Some(Frame::Cassandra(cassandra)) => cassandra.get_table_names(), + Some(Frame::Cassandra(cassandra)) => cassandra.get_table_names().iter().map( |n| n.to_string()).collect_vec(), Some(Frame::Redis(_)) => unimplemented!(), Some(Frame::None) => vec![], _ => unreachable!(), diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 8fb9e8225..60770e3a3 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -77,7 +77,7 @@ fn extract_native_port_column(message: &mut Message) -> Vec { let statement = &cql_statement.statement; if let CassandraStatement::Select(select) = &statement { if let Some(table_name) = CQLStatement::get_table_name(statement) { - if table_name.eq("system.peers_v2") { + if table_name.to_string().eq("system.peers_v2") { select .columns .iter() diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index d5d133396..3b22ed23b 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -9,9 +9,9 @@ use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::Operand; +use cql3_parser::common::{FQName, Operand}; use cql3_parser::insert::InsertValues; -use cql3_parser::select::SelectElement; +use cql3_parser::select::{Select, SelectElement}; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::secretbox; use sodiumoxide::crypto::secretbox::{Key, Nonce}; @@ -33,6 +33,59 @@ pub struct Protect { key_id: String, } +impl Protect { + fn get_protected_columns(&self, table_name : &FQName ) -> Option<&Vec> { + // TODO replace "" with cached keyspace name + if let Some(tables) = self.keyspace_table_columns.get(table_name.extract_keyspace( "")) + { + tables.get(&table_name.name) + } else { + None + } + } + + /// processes the select statement to modify the rows. returns true if the rows were modified + async fn process_select(&self, select : &Select, columns : &Vec, rows : &mut Vec>) -> Result { + let mut modified = false; + + let positions: Vec = select + .columns + .iter() + .enumerate() + .filter_map(|(i, col)| { + if let SelectElement::Column(named) = col { + if columns.contains(&named.name) { + Some(i) + } else { + None + } + } else { + None + } + }) + .collect(); + for row in &mut *rows { + for index in &positions { + if let Some(message_value) = row.get_mut(*index) { + let protected = Protected::extract_result(message_value); + + let new_value: MessageValue = protected.unwrap() + .unprotect( + &self.key_source, + &self.key_id, + ) + .await?; + *message_value = new_value; + modified = true; + } + } + } + + Ok(modified) + } + +} + #[derive(Clone)] pub struct KeyMaterial { pub ciphertext_blob: Bytes, @@ -86,8 +139,8 @@ impl From for MessageValue { "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), Protected::Ciphertext { .. } => { - MessageValue::Bytes(Bytes::from(serde_json::to_vec(&p).unwrap())) - //MessageValue::Bytes(Bytes::from(bincode::serialize(&p).unwrap())) + //MessageValue::Bytes(Bytes::from(serde_json::to_vec(&p).unwrap())) + MessageValue::Bytes(Bytes::from(bincode::serialize(&p).unwrap())) } } } @@ -99,28 +152,48 @@ impl From<&Protected> for Operand { Protected::Plaintext(_) => panic!( "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), - Protected::Ciphertext { .. } => Operand::Const(format!( - "0X{}", - hex::encode(serde_json::to_vec(&p).unwrap()) - )), + Protected::Ciphertext { .. } => Operand::Const(format!("'{}'", hex::encode(serde_json::to_vec(&p).unwrap()))), } } } impl Protected { - pub async fn from_encrypted_bytes_value(value: &MessageValue) -> Result { - match value { - MessageValue::Bytes(b) => { - // let protected_something: Protected = serde_json::from_slice(b.bytes())?; - let protected_something: Protected = bincode::deserialize(b)?; - Ok(protected_something) - } - _ => Err(anyhow!( + fn extract_result(value : &MessageValue ) -> Result { + match value { + MessageValue::Bytes(b) => { + // let protected_something: Protected = serde_json::from_slice(b.bytes())?; + let protected_something = bincode::deserialize(b); + if protected_something.is_err() { + Err(anyhow!( "{:?}", protected_something.err())) + } else { + Ok(protected_something.unwrap()) + } + } + MessageValue::Varchar(s) => { + warn!("varchar {}", s ); + let mut hex_value = s.chars(); + hex_value.next(); + hex_value.next_back(); + + + let byte_value = hex::decode(hex_value.as_str()); + //let plain = decrypt( byte_value); + //let x = bincode::deserialize(&byte_value.unwrap()); + + let protected_something = bincode::deserialize(&byte_value.unwrap()); + if protected_something.is_err() { + Err(anyhow!( "{:?}", protected_something.err())) + } else { + Ok(protected_something.unwrap()) + } + } + _ => Err(anyhow!( "Could not get bytes to decrypt - wrong value type {:?}", value )), - } + } } + // TODO should this actually return self (we are sealing the plaintext value, but we don't swap out the plaintext?? pub async fn protect(self, key_management: &KeyManager, key_id: &str) -> Result { let sym_key = key_management @@ -157,7 +230,13 @@ impl Protected { let sym_key = key_management .cached_get_key(key_id.to_string(), Some(enc_dek), Some(kek_id)) .await?; - decrypt(cipher, nonce, &sym_key.plaintext) + let result = decrypt(cipher, nonce, &sym_key.plaintext); + if result.is_err() { + Err( anyhow!( "{}", result.err().unwrap() )) + } else { + Ok( result.unwrap()) + } + } } } @@ -227,6 +306,7 @@ async fn encrypt_columns( Ok(data_changed) } + #[async_trait] impl Transform for Protect { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { @@ -241,19 +321,16 @@ impl Transform for Protect { { for cql_statement in &mut query.statements { let statement = &mut cql_statement.statement; - if let Some(table_name) = CQLStatement::get_table_name(statement) { - if let Some((_, tables)) = - self.keyspace_table_columns.get_key_value(table_name) - { - if let Some((_, columns)) = tables.get_key_value(table_name) { - data_changed = encrypt_columns( - statement, - columns, - &self.key_source, - &self.key_id, - ) - .await?; + if let Some(columns) = self.get_protected_columns( table_name ) { + data_changed = encrypt_columns( + statement, + &columns, + &self.key_source, + &self.key_id, + ).await?; + if data_changed { + warn!( "statement changed to {}", statement ); } } } @@ -271,69 +348,23 @@ impl Transform for Protect { for (response, request) in result.iter_mut().zip(original_messages.iter_mut()) { let mut invalidate_cache = false; if let Some(Frame::Cassandra(CassandraFrame { - operation: - CassandraOperation::Result(CassandraResult::Rows { - value: MessageValue::Rows(rows), - .. - }), - .. - })) = response.frame() - { + operation: CassandraOperation::Query { query, .. }, + .. + })) = request.frame() { if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) = request.frame() - { + operation: + CassandraOperation::Result(CassandraResult::Rows { + value: MessageValue::Rows(rows), + .. + }), + .. + })) = response.frame() { for cql_statement in &mut query.statements { let statement = &mut cql_statement.statement; if let Some(table_name) = CQLStatement::get_table_name(statement) { - if let Some((_keyspace, tables)) = - self.keyspace_table_columns.get_key_value(table_name) - { - if let Some((_table, protect_columns)) = - tables.get_key_value(table_name) - { - if let CassandraStatement::Select(select) = &statement { - let positions: Vec = select - .columns - .iter() - .enumerate() - .filter_map(|(i, col)| { - if let SelectElement::Column(named) = col { - if protect_columns.contains(&named.name) { - Some(i) - } else { - None - } - } else { - None - } - }) - .collect(); - for row in &mut *rows { - for index in &positions { - if let Some(v) = row.get_mut(*index) { - if let MessageValue::Bytes(_) = v { - let protected = - Protected::from_encrypted_bytes_value( - v, - ) - .await?; - let new_value: MessageValue = protected - .unprotect( - &self.key_source, - &self.key_id, - ) - .await?; - *v = new_value; - invalidate_cache = true; - } else { - warn!("Tried decrypting non-blob column") - } - } - } - } - } + if let Some(columns) = self.get_protected_columns( table_name ) { + if let CassandraStatement::Select(select) = &statement { + invalidate_cache |= self.process_select( select, columns, rows ).await? } } } @@ -350,3 +381,31 @@ impl Transform for Protect { Ok(result) } } + +#[cfg(test)] +mod test { + use serde::Serialize; + use crate::message::MessageValue; + use crate::transforms::protect::key_management::{KeyManagement, KeyManager}; + use crate::transforms::protect::local_kek::LocalKeyManagement; + use crate::transforms::protect::Protected; + use sodiumoxide::crypto::secretbox::{Key, Nonce}; + + #[tokio::test(flavor = "multi_thread")] + //#[test] + async fn round_trip_test() { + if sodiumoxide::init().is_err() { + panic!( "could not init sodiumoxide"); + } + let kek = sodiumoxide::crypto::secretbox::xsalsa20poly1305::gen_key(); + let local_key_mgr = LocalKeyManagement{ kek, kek_id: "".to_string() }; + let key_mgr = KeyManager::Local(local_key_mgr); + let msg_value = MessageValue::Varchar("Hello World".to_string()); + let plain = Protected::Plaintext( msg_value.clone() ); + let encr = plain.protect( &key_mgr, "" ).await.unwrap(); + let new_msg = encr.unprotect( &key_mgr, "" ).await; + + assert_eq!( &msg_value, &new_msg.unwrap() ); + + } +} diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index d9d33d067..c2fa0ccf1 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -670,7 +670,7 @@ fn build_redis_key_from_cql3( build_query_redis_key_from_value_map( table_cache_schema, &value_map, - &select.table_name, + &select.table_name.to_string(), )?, build_query_redis_hash_from_value_map(table_cache_schema, &value_map, select)?, )) @@ -695,7 +695,7 @@ fn build_redis_key_from_cql3( build_query_redis_key_from_value_map( table_cache_schema, &value_map, - &insert.table_name, + &insert.table_name.to_string(), )?, Bytes::new(), )) @@ -706,7 +706,7 @@ fn build_redis_key_from_cql3( build_query_redis_key_from_value_map( table_cache_schema, &value_map, - &update.table_name, + &update.table_name.to_string(), )?, Bytes::new(), )) diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index 0a8950100..4dfabd0e4 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1403,7 +1403,7 @@ mod protect { if let ResultValue::Varchar(value) = &result[0][2] { assert!(value.starts_with("{\"Ciphertext")); } else { - panic!("expectected 3rd column to be ResultValue::Varchar in {result:?}"); + panic!("expected 3rd column to be ResultValue::Varchar in {result:?}"); } assert_eq!(result[0][3], ResultValue::Int(42)); assert_eq!(result[0][4], ResultValue::Boolean(true)); From 154a3f4de8aa313f96a406220e3b053dcf5f8224 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 22 Apr 2022 13:39:46 +0100 Subject: [PATCH 32/60] fixed protect code --- Cargo.lock | 1 + shotover-proxy/Cargo.toml | 1 + shotover-proxy/src/frame/cassandra.rs | 9 +- shotover-proxy/src/transforms/protect/mod.rs | 367 +++++++++++------- .../cassandra_int_tests/basic_driver_tests.rs | 4 +- 5 files changed, 245 insertions(+), 137 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83f4af890..e389421bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2579,6 +2579,7 @@ dependencies = [ "rayon", "redis", "redis-protocol", + "regex", "reqwest", "rusoto_kms", "rusoto_signature", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 7bab5492a..b5f7d2950 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -91,6 +91,7 @@ hex-literal = "0.3.3" nix = "0.23.0" reqwest = "0.11.6" metrics-util = "0.12.0" +regex = "1.5.5" [[bench]] diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index fd4fd174d..8016c8c44 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -953,7 +953,14 @@ impl ToCassandraType for Operand { fn from_string_value(value: &str) -> Option { // check for string types if value.starts_with('\'') || value.starts_with("$$") { - Some(CassandraType::Varchar(value.to_string())) + let mut chars = value.chars(); + chars.next(); + chars.next_back(); + if value.starts_with('$') { + chars.next(); + chars.next_back(); + } + Some(CassandraType::Varchar(chars.as_str().to_string())) } else if value.starts_with("0X") || value.starts_with("0x") { let mut chars = value.chars(); chars.next(); diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 3b22ed23b..c916eeede 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -18,6 +18,7 @@ use sodiumoxide::crypto::secretbox::{Key, Nonce}; use sodiumoxide::hex; use std::collections::HashMap; use tracing::warn; +use tracing_log::log::debug; mod aws_kms; mod key_management; @@ -34,6 +35,35 @@ pub struct Protect { } impl Protect { + + /// encodes a Protected object into a byte array. This is here to centeralize the serde for + /// the Protected object. + /// Returns an error if a Plaintext Protected object is passed + fn encode(protected : &Protected ) -> Result> { + match protected { + Protected::Plaintext(_) => { Err( anyhow!("can not encode plain text")) } + Protected::Ciphertext {..} => { + match serde_json::to_vec(protected) { + Ok(data) => { Ok( data )} + Err(e) => { Err( anyhow!( "{:?}", e ))} + } } + } + } + + /// decodes a byte array into the Protected object. This is here to centeralize the serde for + /// the Protected object. + fn decode(data : &[u8] ) -> Result { + let result = serde_json::from_slice( &data ); + if result.is_err() { + Err(anyhow!( "{:?}", result.err())) + } else { + let decoded :Protected = result.unwrap(); + Ok(decoded) + } + } + + /// get the list of protected columns for the specified table name. Will return `None` if no columns + /// are defined for the table. fn get_protected_columns(&self, table_name : &FQName ) -> Option<&Vec> { // TODO replace "" with cached keyspace name if let Some(tables) = self.keyspace_table_columns.get(table_name.extract_keyspace( "")) @@ -44,10 +74,92 @@ impl Protect { } } - /// processes the select statement to modify the rows. returns true if the rows were modified + /// extractes the protected object from the message value. Resulting object is a Protected::Ciphertext + fn extract_protected(&self, value : &MessageValue ) -> Result { + match value { + MessageValue::Bytes(b) => { + Protect::decode(&b[..] ) + } + MessageValue::Varchar(hex_value) => { + let byte_value = hex::decode(hex_value ); + Protect::decode(&byte_value.unwrap() ) + } + _ => Err(anyhow!( + "Could not get bytes to decrypt - wrong value type {:?}", + value + )), + } + } + + /// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. + /// * `statement` the statement to encrypt. + /// * `columns` the column names to encrypt. + /// * `key_source` the key manager with encryption keys. + /// * `key_id` the key within the manager to use. + async fn encrypt_columns(&self, + statement: &mut CassandraStatement + ) -> Result { + let mut data_changed = false; + if let Some(table_name) = CQLStatement::get_table_name(statement) { + if let Some(columns) = self.get_protected_columns(table_name) { + match statement { + CassandraStatement::Insert(insert) => { + // get the indices of the inserted protected columns + let indices: Vec = insert + .columns + .iter() + .enumerate() + .filter_map(|(i, col_name)| { + if columns.contains(col_name) { + Some(i) + } else { + None + } + }) + .collect(); + // if there are columns process them + if ! indices.is_empty() { + match &mut insert.values { + InsertValues::Values(value_operands) => { + for idx in indices { + let mut protected = + Protected::Plaintext(MessageValue::from(&value_operands[idx])); + protected = protected.protect(&self.key_source, &self.key_id).await?; + value_operands[idx] = Operand::from(&protected); + data_changed = true + } + } + InsertValues::Json(_) => { + // TODO parse json and encrypt. + } + } + } + } + CassandraStatement::Update(update) => { + for assignment in &mut update.assignments { + if columns.contains(&assignment.name.column) { + let mut protected = Protected::Plaintext(MessageValue::from(&assignment.value)); + protected = protected.protect(&self.key_source, &self.key_id).await?; + assignment.value = Operand::from(&protected); + data_changed = true; + } + } + } + _ => { + // no other statement are modified + } + } + } + } + Ok(data_changed) + } + + + /// processes the select statement to modify the rows. returns `true` if the rows were modified async fn process_select(&self, select : &Select, columns : &Vec, rows : &mut Vec>) -> Result { let mut modified = false; + // get the positions of the protected columns in the result let positions: Vec = select .columns .iter() @@ -64,23 +176,19 @@ impl Protect { } }) .collect(); - for row in &mut *rows { - for index in &positions { - if let Some(message_value) = row.get_mut(*index) { - let protected = Protected::extract_result(message_value); - - let new_value: MessageValue = protected.unwrap() - .unprotect( - &self.key_source, - &self.key_id, - ) - .await?; - *message_value = new_value; - modified = true; + // only do the work if there are columns we are interested in + if positions.len() > 0 { + for row in &mut *rows { + for index in &positions { + if let Some(message_value) = row.get_mut(*index) { + let protected = self.extract_protected(message_value).unwrap(); + let new_value = protected.unprotect( &self.key_source, &self.key_id).await?; + *message_value = new_value; + modified = true; + } } } } - Ok(modified) } @@ -114,20 +222,21 @@ pub enum Protected { }, } -fn encrypt(plaintext: Vec, sym_key: &Key) -> (Vec, Nonce) { +/// encrypts the message value +fn encrypt(message_value : &MessageValue, sym_key: &Key) -> (Vec, Nonce) { + let ser = bincode::serialize( message_value ); let nonce = secretbox::gen_nonce(); - let ciphertext = secretbox::seal(&plaintext, &nonce, sym_key); + let ciphertext = secretbox::seal(&ser.unwrap(), &nonce, sym_key); (ciphertext, nonce) } +/// decrypts a message value fn decrypt(ciphertext: Vec, nonce: Nonce, sym_key: &Key) -> Result { let decrypted_bytes = secretbox::open(&ciphertext, &nonce, sym_key).map_err(|_| anyhow!("couldn't open box"))?; //TODO make error handing better here - failure here indicates a authenticity failure - let decrypted_value: MessageValue = serde_json::from_slice(decrypted_bytes.as_slice()) - .map_err(|_| anyhow!("couldn't open box"))?; - // let decrypted_value: MessageValue = - // bincode::deserialize(&decrypted_bytes).map_err(|_| anyhow!("couldn't open box"))?; + let decrypted_value: MessageValue = + bincode::deserialize(&decrypted_bytes).map_err(|_| anyhow!("couldn't open box"))?; Ok(decrypted_value) } @@ -152,47 +261,12 @@ impl From<&Protected> for Operand { Protected::Plaintext(_) => panic!( "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), - Protected::Ciphertext { .. } => Operand::Const(format!("'{}'", hex::encode(serde_json::to_vec(&p).unwrap()))), + Protected::Ciphertext { .. } => Operand::Const(format!("'{}'", hex::encode(Protect::encode(&p).unwrap()))), } } } impl Protected { - fn extract_result(value : &MessageValue ) -> Result { - match value { - MessageValue::Bytes(b) => { - // let protected_something: Protected = serde_json::from_slice(b.bytes())?; - let protected_something = bincode::deserialize(b); - if protected_something.is_err() { - Err(anyhow!( "{:?}", protected_something.err())) - } else { - Ok(protected_something.unwrap()) - } - } - MessageValue::Varchar(s) => { - warn!("varchar {}", s ); - let mut hex_value = s.chars(); - hex_value.next(); - hex_value.next_back(); - - - let byte_value = hex::decode(hex_value.as_str()); - //let plain = decrypt( byte_value); - //let x = bincode::deserialize(&byte_value.unwrap()); - - let protected_something = bincode::deserialize(&byte_value.unwrap()); - if protected_something.is_err() { - Err(anyhow!( "{:?}", protected_something.err())) - } else { - Ok(protected_something.unwrap()) - } - } - _ => Err(anyhow!( - "Could not get bytes to decrypt - wrong value type {:?}", - value - )), - } - } // TODO should this actually return self (we are sealing the plaintext value, but we don't swap out the plaintext?? pub async fn protect(self, key_management: &KeyManager, key_id: &str) -> Result { @@ -200,9 +274,8 @@ impl Protected { .cached_get_key(key_id.to_string(), None, None) .await?; match &self { - Protected::Plaintext(p) => { - // let (cipher, nonce) = encrypt(serde_json::to_string(p).unwrap(), &sym_key.plaintext); - let (cipher, nonce) = encrypt(bincode::serialize(&p).unwrap(), &sym_key.plaintext); + Protected::Plaintext(message_value) => { + let (cipher, nonce) = encrypt(&message_value, &sym_key.plaintext); Ok(Protected::Ciphertext { cipher, nonce, @@ -252,59 +325,6 @@ impl ProtectConfig { } } -/// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. -async fn encrypt_columns( - statement: &mut CassandraStatement, - columns: &[String], - key_source: &KeyManager, - key_id: &str, -) -> Result { - let mut data_changed = false; - match statement { - CassandraStatement::Insert(insert) => { - let indices: Vec = insert - .columns - .iter() - .enumerate() - .filter_map(|(i, col_name)| { - if columns.contains(col_name) { - Some(i) - } else { - None - } - }) - .collect(); - match &mut insert.values { - InsertValues::Values(value_operands) => { - for idx in indices { - let mut protected = - Protected::Plaintext(MessageValue::from(&value_operands[idx])); - protected = protected.protect(key_source, key_id).await?; - value_operands[idx] = Operand::from(&protected); - data_changed = true - } - } - InsertValues::Json(_) => { - // TODO parse json and encrypt. - } - } - } - CassandraStatement::Update(update) => { - for assignment in &mut update.assignments { - if columns.contains(&assignment.name.column) { - let mut protected = Protected::Plaintext(MessageValue::from(&assignment.value)); - protected = protected.protect(key_source, key_id).await?; - assignment.value = Operand::from(&protected); - data_changed = true; - } - } - } - _ => { - // no other statement are modified - } - } - Ok(data_changed) -} #[async_trait] @@ -321,22 +341,12 @@ impl Transform for Protect { { for cql_statement in &mut query.statements { let statement = &mut cql_statement.statement; - if let Some(table_name) = CQLStatement::get_table_name(statement) { - if let Some(columns) = self.get_protected_columns( table_name ) { - data_changed = encrypt_columns( - statement, - &columns, - &self.key_source, - &self.key_id, - ).await?; - if data_changed { - warn!( "statement changed to {}", statement ); - } - } + data_changed |= self.encrypt_columns(statement).await.unwrap(); + if data_changed { + debug!( "statement changed to {}", statement ); } } } - if data_changed { message.invalidate_cache(); } @@ -384,12 +394,44 @@ impl Transform for Protect { #[cfg(test)] mod test { - use serde::Serialize; + use std::collections::HashMap; + use bytes::Bytes; + use cql3_parser::cassandra_statement::CassandraStatement; + use cql3_parser::common::Operand; + use cql3_parser::insert::InsertValues; + use futures::future::ok; + use sodiumoxide::crypto::secretbox::Nonce; use crate::message::MessageValue; - use crate::transforms::protect::key_management::{KeyManagement, KeyManager}; + use crate::transforms::protect::key_management::KeyManager; use crate::transforms::protect::local_kek::LocalKeyManagement; - use crate::transforms::protect::Protected; - use sodiumoxide::crypto::secretbox::{Key, Nonce}; + use crate::transforms::protect::{Protect, Protected}; + use crate::frame::CQL; + + #[test] + fn test_serde() { + let n : [u8;24] = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]; + let ocipher =Bytes::from( "this would be encrypted data" ).to_vec(); + let ononce = Nonce::from_slice(&n).unwrap(); + let oenc_dek = Bytes::from( "this would be enc_dek" ).to_vec(); + let okek_id = "The KEK id".to_string(); + + let protected = Protected::Ciphertext { + cipher : ocipher.clone(), + nonce : ononce.clone(), + enc_dek : oenc_dek.clone(), + kek_id : okek_id.clone(), + }; + let encoded = Protect::encode( &protected ).unwrap(); + let decoded = Protect::decode( &encoded ).unwrap(); + if let Protected::Ciphertext { cipher,nonce,enc_dek,kek_id } = decoded { + assert_eq!( &ocipher, &cipher); + assert_eq!( &ononce, &nonce); + assert_eq!( &oenc_dek, &enc_dek); + assert_eq!( &okek_id, &kek_id); + } else { + panic!( "not a Ciphertext") + } + } #[tokio::test(flavor = "multi_thread")] //#[test] @@ -397,15 +439,70 @@ mod test { if sodiumoxide::init().is_err() { panic!( "could not init sodiumoxide"); } + + // verify low level round trip works. let kek = sodiumoxide::crypto::secretbox::xsalsa20poly1305::gen_key(); let local_key_mgr = LocalKeyManagement{ kek, kek_id: "".to_string() }; - let key_mgr = KeyManager::Local(local_key_mgr); + + let cols = vec!["col1".to_string()]; + let mut tables = HashMap::new(); + tables.insert( "test_table".to_string() , cols.clone() ); + let mut keyspace_table_columns = HashMap::new(); + keyspace_table_columns.insert("".to_string(),tables); + let protect = Protect{ + keyspace_table_columns, + key_source: KeyManager::Local(local_key_mgr), + key_id: "".to_string() + }; + + // test protect/unprotect works let msg_value = MessageValue::Varchar("Hello World".to_string()); let plain = Protected::Plaintext( msg_value.clone() ); - let encr = plain.protect( &key_mgr, "" ).await.unwrap(); - let new_msg = encr.unprotect( &key_mgr, "" ).await; - - assert_eq!( &msg_value, &new_msg.unwrap() ); + let encr = plain.protect( &protect.key_source, &protect.key_id ).await.unwrap(); + let new_msg = encr.unprotect( &protect.key_source, &protect.key_id ).await.unwrap(); + assert_eq!( &msg_value, &new_msg ); + + // test insert change is reversed on select + let stmt_txt = "insert into test_table (col1, col2) VALUES ('Hello World', 'i am clean')"; + let mut cql = CQL::parse_from_string(stmt_txt); + let statement = &mut cql.statements[0].statement; + let data_changed = protect.encrypt_columns(statement).await.unwrap(); + assert!( data_changed ); + + let mut encr_value = String::new(); + if let CassandraStatement::Insert(insert) = statement { + if let InsertValues::Values(operands) = &insert.values { + if let Operand::Const( value ) = &operands[0] { + encr_value = value.clone(); + assert!( ! value.eq( "Hello World'") ); + } else { + panic!( "Not a const value"); + } + } else { + panic!( "Not a InsertValues::Values object"); + } + } else { + panic!("not an INSERT"); + } + // remove the quotes + let mut hex_value = encr_value.chars(); + hex_value.next(); + hex_value.next_back(); + let s = hex_value.as_str().to_string(); + let mv = MessageValue::Varchar(s); + // build the row + let row = vec![mv]; + let mut rows = vec![row]; + + let stmt_txt = "select col1 from test_table where col2='i am clean'"; + let cql = CQL::parse_from_string(stmt_txt); + let statement = &cql.statements[0].statement; + + if let CassandraStatement::Select(select) = statement { + let result = protect.process_select(select, &cols, &mut rows ).await; + assert!( result.unwrap() ); + assert_eq!( &msg_value, &rows[0][0] ); + } } } diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index 4dfabd0e4..b4757a605 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1354,6 +1354,7 @@ mod cache { mod protect { use crate::helpers::cassandra::{execute_query, run_query, ResultValue}; use cassandra_cpp::Session; + use regex::Regex; pub fn test(shotover_session: &Session, direct_session: &Session) { run_query(shotover_session, "CREATE KEYSPACE test_protect_keyspace WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"); @@ -1401,7 +1402,8 @@ mod protect { assert_eq!(result[0][0], ResultValue::Varchar("pk1".into())); assert_eq!(result[0][1], ResultValue::Varchar("cluster".into())); if let ResultValue::Varchar(value) = &result[0][2] { - assert!(value.starts_with("{\"Ciphertext")); + let re = Regex::new( r"^[0-9a-fA-F]+$").unwrap(); + assert!( re.is_match( value )); } else { panic!("expected 3rd column to be ResultValue::Varchar in {result:?}"); } From aff9552b13ba9c219dfe87a99e3fa0baf5e1c459 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 22 Apr 2022 14:06:33 +0100 Subject: [PATCH 33/60] fixed cargo errors --- shotover-proxy/Cargo.toml | 2 +- shotover-proxy/src/transforms/protect/mod.rs | 71 +++++++++----------- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index b5f7d2950..713c19e5b 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -9,7 +9,7 @@ license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["alpha-transforms"] +#default = ["alpha-transforms"] # Include WIP alpha transforms in the public API alpha-transforms = [] diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index c916eeede..8408d361b 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -17,7 +17,6 @@ use sodiumoxide::crypto::secretbox; use sodiumoxide::crypto::secretbox::{Key, Nonce}; use sodiumoxide::hex; use std::collections::HashMap; -use tracing::warn; use tracing_log::log::debug; mod aws_kms; @@ -53,12 +52,10 @@ impl Protect { /// decodes a byte array into the Protected object. This is here to centeralize the serde for /// the Protected object. fn decode(data : &[u8] ) -> Result { - let result = serde_json::from_slice( &data ); - if result.is_err() { - Err(anyhow!( "{:?}", result.err())) - } else { - let decoded :Protected = result.unwrap(); - Ok(decoded) + let result = serde_json::from_slice( data ); + match result { + Ok(decoded) => { Ok(decoded)} + Err(e) => {Err(anyhow!( "{:?}", e))} } } @@ -156,7 +153,7 @@ impl Protect { /// processes the select statement to modify the rows. returns `true` if the rows were modified - async fn process_select(&self, select : &Select, columns : &Vec, rows : &mut Vec>) -> Result { + async fn process_select(&self, select : &Select, columns : &[String], rows : &mut Vec>) -> Result { let mut modified = false; // get the positions of the protected columns in the result @@ -177,7 +174,7 @@ impl Protect { }) .collect(); // only do the work if there are columns we are interested in - if positions.len() > 0 { + if ! positions.is_empty() { for row in &mut *rows { for index in &positions { if let Some(message_value) = row.get_mut(*index) { @@ -256,12 +253,12 @@ impl From for MessageValue { } impl From<&Protected> for Operand { - fn from(p: &Protected) -> Self { - match p { + fn from(protected: &Protected) -> Self { + match protected { Protected::Plaintext(_) => panic!( "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), - Protected::Ciphertext { .. } => Operand::Const(format!("'{}'", hex::encode(Protect::encode(&p).unwrap()))), + Protected::Ciphertext { .. } => Operand::Const(format!("'{}'", hex::encode(Protect::encode(protected).unwrap()))), } } } @@ -275,7 +272,7 @@ impl Protected { .await?; match &self { Protected::Plaintext(message_value) => { - let (cipher, nonce) = encrypt(&message_value, &sym_key.plaintext); + let (cipher, nonce) = encrypt(message_value, &sym_key.plaintext); Ok(Protected::Ciphertext { cipher, nonce, @@ -399,7 +396,6 @@ mod test { use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use cql3_parser::insert::InsertValues; - use futures::future::ok; use sodiumoxide::crypto::secretbox::Nonce; use crate::message::MessageValue; use crate::transforms::protect::key_management::KeyManager; @@ -417,7 +413,7 @@ mod test { let protected = Protected::Ciphertext { cipher : ocipher.clone(), - nonce : ononce.clone(), + nonce : ononce, enc_dek : oenc_dek.clone(), kek_id : okek_id.clone(), }; @@ -469,12 +465,30 @@ mod test { let data_changed = protect.encrypt_columns(statement).await.unwrap(); assert!( data_changed ); - let mut encr_value = String::new(); if let CassandraStatement::Insert(insert) = statement { if let InsertValues::Values(operands) = &insert.values { - if let Operand::Const( value ) = &operands[0] { - encr_value = value.clone(); - assert!( ! value.eq( "Hello World'") ); + if let Operand::Const( encr_value ) = &operands[0] { + assert!( ! encr_value.eq( "Hello World'") ); + // remove the quotes + let mut hex_value = encr_value.chars(); + hex_value.next(); + hex_value.next_back(); + let s = hex_value.as_str().to_string(); + let mv = MessageValue::Varchar(s); + // build the row + let row = vec![mv]; + let mut rows = vec![row]; + + let stmt_txt = "select col1 from test_table where col2='i am clean'"; + let cql = CQL::parse_from_string(stmt_txt); + let statement = &cql.statements[0].statement; + + if let CassandraStatement::Select(select) = statement { + let result = protect.process_select(select, &cols, &mut rows ).await; + assert!( result.unwrap() ); + assert_eq!( &msg_value, &rows[0][0] ); + } + } else { panic!( "Not a const value"); } @@ -485,24 +499,5 @@ mod test { panic!("not an INSERT"); } - // remove the quotes - let mut hex_value = encr_value.chars(); - hex_value.next(); - hex_value.next_back(); - let s = hex_value.as_str().to_string(); - let mv = MessageValue::Varchar(s); - // build the row - let row = vec![mv]; - let mut rows = vec![row]; - - let stmt_txt = "select col1 from test_table where col2='i am clean'"; - let cql = CQL::parse_from_string(stmt_txt); - let statement = &cql.statements[0].statement; - - if let CassandraStatement::Select(select) = statement { - let result = protect.process_select(select, &cols, &mut rows ).await; - assert!( result.unwrap() ); - assert_eq!( &msg_value, &rows[0][0] ); - } } } From 4fe3b705982df6010e9878fcbb88b1fb4a672b5c Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 25 Apr 2022 05:59:11 +0100 Subject: [PATCH 34/60] fixed cargo issues --- shotover-proxy/src/frame/cassandra.rs | 2 +- shotover-proxy/src/message/mod.rs | 6 +- shotover-proxy/src/transforms/protect/mod.rs | 223 ++++++++++-------- .../cassandra_int_tests/basic_driver_tests.rs | 6 +- 4 files changed, 131 insertions(+), 106 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 8016c8c44..14d17b099 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -22,7 +22,7 @@ use cassandra_protocol::types::value::Value; use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{Operand, RelationElement, FQName}; +use cql3_parser::common::{FQName, Operand, RelationElement}; use cql3_parser::insert::InsertValues; use cql3_parser::update::AssignmentOperator; use itertools::Itertools; diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index d9a29d792..3e43db9d3 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -177,7 +177,11 @@ impl Message { /// None if the statements do not contain table names. pub fn get_table_names(&mut self) -> Vec { match self.frame() { - Some(Frame::Cassandra(cassandra)) => cassandra.get_table_names().iter().map( |n| n.to_string()).collect_vec(), + Some(Frame::Cassandra(cassandra)) => cassandra + .get_table_names() + .iter() + .map(|n| n.to_string()) + .collect_vec(), Some(Frame::Redis(_)) => unimplemented!(), Some(Frame::None) => vec![], _ => unreachable!(), diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 8408d361b..c2c1e3d5a 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -34,36 +34,36 @@ pub struct Protect { } impl Protect { - /// encodes a Protected object into a byte array. This is here to centeralize the serde for /// the Protected object. /// Returns an error if a Plaintext Protected object is passed - fn encode(protected : &Protected ) -> Result> { + fn encode(protected: &Protected) -> Result> { match protected { - Protected::Plaintext(_) => { Err( anyhow!("can not encode plain text")) } - Protected::Ciphertext {..} => { - match serde_json::to_vec(protected) { - Ok(data) => { Ok( data )} - Err(e) => { Err( anyhow!( "{:?}", e ))} - } } + Protected::Plaintext(_) => Err(anyhow!("can not encode plain text")), + Protected::Ciphertext { .. } => match serde_json::to_vec(protected) { + Ok(data) => Ok(data), + Err(e) => Err(anyhow!("{:?}", e)), + }, } } /// decodes a byte array into the Protected object. This is here to centeralize the serde for /// the Protected object. - fn decode(data : &[u8] ) -> Result { - let result = serde_json::from_slice( data ); + fn decode(data: &[u8]) -> Result { + let result = serde_json::from_slice(data); match result { - Ok(decoded) => { Ok(decoded)} - Err(e) => {Err(anyhow!( "{:?}", e))} + Ok(decoded) => Ok(decoded), + Err(e) => Err(anyhow!("{:?}", e)), } } /// get the list of protected columns for the specified table name. Will return `None` if no columns /// are defined for the table. - fn get_protected_columns(&self, table_name : &FQName ) -> Option<&Vec> { + fn get_protected_columns(&self, table_name: &FQName) -> Option<&Vec> { // TODO replace "" with cached keyspace name - if let Some(tables) = self.keyspace_table_columns.get(table_name.extract_keyspace( "")) + if let Some(tables) = self + .keyspace_table_columns + .get(table_name.extract_keyspace("")) { tables.get(&table_name.name) } else { @@ -72,14 +72,12 @@ impl Protect { } /// extractes the protected object from the message value. Resulting object is a Protected::Ciphertext - fn extract_protected(&self, value : &MessageValue ) -> Result { + fn extract_protected(&self, value: &MessageValue) -> Result { match value { - MessageValue::Bytes(b) => { - Protect::decode(&b[..] ) - } + MessageValue::Bytes(b) => Protect::decode(&b[..]), MessageValue::Varchar(hex_value) => { - let byte_value = hex::decode(hex_value ); - Protect::decode(&byte_value.unwrap() ) + let byte_value = hex::decode(hex_value); + Protect::decode(&byte_value.unwrap()) } _ => Err(anyhow!( "Could not get bytes to decrypt - wrong value type {:?}", @@ -93,9 +91,7 @@ impl Protect { /// * `columns` the column names to encrypt. /// * `key_source` the key manager with encryption keys. /// * `key_id` the key within the manager to use. - async fn encrypt_columns(&self, - statement: &mut CassandraStatement - ) -> Result { + async fn encrypt_columns(&self, statement: &mut CassandraStatement) -> Result { let mut data_changed = false; if let Some(table_name) = CQLStatement::get_table_name(statement) { if let Some(columns) = self.get_protected_columns(table_name) { @@ -115,13 +111,16 @@ impl Protect { }) .collect(); // if there are columns process them - if ! indices.is_empty() { + if !indices.is_empty() { match &mut insert.values { InsertValues::Values(value_operands) => { for idx in indices { - let mut protected = - Protected::Plaintext(MessageValue::from(&value_operands[idx])); - protected = protected.protect(&self.key_source, &self.key_id).await?; + let mut protected = Protected::Plaintext( + MessageValue::from(&value_operands[idx]), + ); + protected = protected + .protect(&self.key_source, &self.key_id) + .await?; value_operands[idx] = Operand::from(&protected); data_changed = true } @@ -135,8 +134,10 @@ impl Protect { CassandraStatement::Update(update) => { for assignment in &mut update.assignments { if columns.contains(&assignment.name.column) { - let mut protected = Protected::Plaintext(MessageValue::from(&assignment.value)); - protected = protected.protect(&self.key_source, &self.key_id).await?; + let mut protected = + Protected::Plaintext(MessageValue::from(&assignment.value)); + protected = + protected.protect(&self.key_source, &self.key_id).await?; assignment.value = Operand::from(&protected); data_changed = true; } @@ -151,9 +152,13 @@ impl Protect { Ok(data_changed) } - /// processes the select statement to modify the rows. returns `true` if the rows were modified - async fn process_select(&self, select : &Select, columns : &[String], rows : &mut Vec>) -> Result { + async fn process_select( + &self, + select: &Select, + columns: &[String], + rows: &mut Vec>, + ) -> Result { let mut modified = false; // get the positions of the protected columns in the result @@ -174,12 +179,12 @@ impl Protect { }) .collect(); // only do the work if there are columns we are interested in - if ! positions.is_empty() { + if !positions.is_empty() { for row in &mut *rows { for index in &positions { if let Some(message_value) = row.get_mut(*index) { let protected = self.extract_protected(message_value).unwrap(); - let new_value = protected.unprotect( &self.key_source, &self.key_id).await?; + let new_value = protected.unprotect(&self.key_source, &self.key_id).await?; *message_value = new_value; modified = true; } @@ -188,7 +193,6 @@ impl Protect { } Ok(modified) } - } #[derive(Clone)] @@ -220,8 +224,8 @@ pub enum Protected { } /// encrypts the message value -fn encrypt(message_value : &MessageValue, sym_key: &Key) -> (Vec, Nonce) { - let ser = bincode::serialize( message_value ); +fn encrypt(message_value: &MessageValue, sym_key: &Key) -> (Vec, Nonce) { + let ser = bincode::serialize(message_value); let nonce = secretbox::gen_nonce(); let ciphertext = secretbox::seal(&ser.unwrap(), &nonce, sym_key); (ciphertext, nonce) @@ -232,8 +236,8 @@ fn decrypt(ciphertext: Vec, nonce: Nonce, sym_key: &Key) -> Result for Operand { Protected::Plaintext(_) => panic!( "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), - Protected::Ciphertext { .. } => Operand::Const(format!("'{}'", hex::encode(Protect::encode(protected).unwrap()))), + Protected::Ciphertext { .. } => Operand::Const(format!( + "'{}'", + hex::encode(Protect::encode(protected).unwrap()) + )), } } } impl Protected { - // TODO should this actually return self (we are sealing the plaintext value, but we don't swap out the plaintext?? pub async fn protect(self, key_management: &KeyManager, key_id: &str) -> Result { let sym_key = key_management @@ -302,11 +308,10 @@ impl Protected { .await?; let result = decrypt(cipher, nonce, &sym_key.plaintext); if result.is_err() { - Err( anyhow!( "{}", result.err().unwrap() )) + Err(anyhow!("{}", result.err().unwrap())) } else { - Ok( result.unwrap()) + Ok(result.unwrap()) } - } } } @@ -322,8 +327,6 @@ impl ProtectConfig { } } - - #[async_trait] impl Transform for Protect { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { @@ -340,7 +343,7 @@ impl Transform for Protect { let statement = &mut cql_statement.statement; data_changed |= self.encrypt_columns(statement).await.unwrap(); if data_changed { - debug!( "statement changed to {}", statement ); + debug!("statement changed to {}", statement); } } } @@ -355,23 +358,26 @@ impl Transform for Protect { for (response, request) in result.iter_mut().zip(original_messages.iter_mut()) { let mut invalidate_cache = false; if let Some(Frame::Cassandra(CassandraFrame { - operation: CassandraOperation::Query { query, .. }, - .. - })) = request.frame() { + operation: CassandraOperation::Query { query, .. }, + .. + })) = request.frame() + { if let Some(Frame::Cassandra(CassandraFrame { - operation: - CassandraOperation::Result(CassandraResult::Rows { - value: MessageValue::Rows(rows), - .. - }), - .. - })) = response.frame() { + operation: + CassandraOperation::Result(CassandraResult::Rows { + value: MessageValue::Rows(rows), + .. + }), + .. + })) = response.frame() + { for cql_statement in &mut query.statements { let statement = &mut cql_statement.statement; if let Some(table_name) = CQLStatement::get_table_name(statement) { - if let Some(columns) = self.get_protected_columns( table_name ) { + if let Some(columns) = self.get_protected_columns(table_name) { if let CassandraStatement::Select(select) = &statement { - invalidate_cache |= self.process_select( select, columns, rows ).await? + invalidate_cache |= + self.process_select(select, columns, rows).await? } } } @@ -391,41 +397,49 @@ impl Transform for Protect { #[cfg(test)] mod test { - use std::collections::HashMap; + use crate::frame::CQL; + use crate::message::MessageValue; + use crate::transforms::protect::key_management::KeyManager; + use crate::transforms::protect::local_kek::LocalKeyManagement; + use crate::transforms::protect::{Protect, Protected}; use bytes::Bytes; use cql3_parser::cassandra_statement::CassandraStatement; use cql3_parser::common::Operand; use cql3_parser::insert::InsertValues; use sodiumoxide::crypto::secretbox::Nonce; - use crate::message::MessageValue; - use crate::transforms::protect::key_management::KeyManager; - use crate::transforms::protect::local_kek::LocalKeyManagement; - use crate::transforms::protect::{Protect, Protected}; - use crate::frame::CQL; + use std::collections::HashMap; #[test] fn test_serde() { - let n : [u8;24] = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]; - let ocipher =Bytes::from( "this would be encrypted data" ).to_vec(); + let n: [u8; 24] = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ]; + let ocipher = Bytes::from("this would be encrypted data").to_vec(); let ononce = Nonce::from_slice(&n).unwrap(); - let oenc_dek = Bytes::from( "this would be enc_dek" ).to_vec(); + let oenc_dek = Bytes::from("this would be enc_dek").to_vec(); let okek_id = "The KEK id".to_string(); let protected = Protected::Ciphertext { - cipher : ocipher.clone(), - nonce : ononce, - enc_dek : oenc_dek.clone(), - kek_id : okek_id.clone(), + cipher: ocipher.clone(), + nonce: ononce, + enc_dek: oenc_dek.clone(), + kek_id: okek_id.clone(), }; - let encoded = Protect::encode( &protected ).unwrap(); - let decoded = Protect::decode( &encoded ).unwrap(); - if let Protected::Ciphertext { cipher,nonce,enc_dek,kek_id } = decoded { - assert_eq!( &ocipher, &cipher); - assert_eq!( &ononce, &nonce); - assert_eq!( &oenc_dek, &enc_dek); - assert_eq!( &okek_id, &kek_id); + let encoded = Protect::encode(&protected).unwrap(); + let decoded = Protect::decode(&encoded).unwrap(); + if let Protected::Ciphertext { + cipher, + nonce, + enc_dek, + kek_id, + } = decoded + { + assert_eq!(&ocipher, &cipher); + assert_eq!(&ononce, &nonce); + assert_eq!(&oenc_dek, &enc_dek); + assert_eq!(&okek_id, &kek_id); } else { - panic!( "not a Ciphertext") + panic!("not a Ciphertext") } } @@ -433,47 +447,56 @@ mod test { //#[test] async fn round_trip_test() { if sodiumoxide::init().is_err() { - panic!( "could not init sodiumoxide"); + panic!("could not init sodiumoxide"); } // verify low level round trip works. let kek = sodiumoxide::crypto::secretbox::xsalsa20poly1305::gen_key(); - let local_key_mgr = LocalKeyManagement{ kek, kek_id: "".to_string() }; + let local_key_mgr = LocalKeyManagement { + kek, + kek_id: "".to_string(), + }; let cols = vec!["col1".to_string()]; let mut tables = HashMap::new(); - tables.insert( "test_table".to_string() , cols.clone() ); + tables.insert("test_table".to_string(), cols.clone()); let mut keyspace_table_columns = HashMap::new(); - keyspace_table_columns.insert("".to_string(),tables); - let protect = Protect{ + keyspace_table_columns.insert("".to_string(), tables); + let protect = Protect { keyspace_table_columns, key_source: KeyManager::Local(local_key_mgr), - key_id: "".to_string() + key_id: "".to_string(), }; // test protect/unprotect works - let msg_value = MessageValue::Varchar("Hello World".to_string()); - let plain = Protected::Plaintext( msg_value.clone() ); - let encr = plain.protect( &protect.key_source, &protect.key_id ).await.unwrap(); - let new_msg = encr.unprotect( &protect.key_source, &protect.key_id ).await.unwrap(); - assert_eq!( &msg_value, &new_msg ); + let msg_value = MessageValue::Varchar("Hello World".to_string()); + let plain = Protected::Plaintext(msg_value.clone()); + let encr = plain + .protect(&protect.key_source, &protect.key_id) + .await + .unwrap(); + let new_msg = encr + .unprotect(&protect.key_source, &protect.key_id) + .await + .unwrap(); + assert_eq!(&msg_value, &new_msg); // test insert change is reversed on select let stmt_txt = "insert into test_table (col1, col2) VALUES ('Hello World', 'i am clean')"; let mut cql = CQL::parse_from_string(stmt_txt); let statement = &mut cql.statements[0].statement; let data_changed = protect.encrypt_columns(statement).await.unwrap(); - assert!( data_changed ); + assert!(data_changed); if let CassandraStatement::Insert(insert) = statement { if let InsertValues::Values(operands) = &insert.values { - if let Operand::Const( encr_value ) = &operands[0] { - assert!( ! encr_value.eq( "Hello World'") ); + if let Operand::Const(encr_value) = &operands[0] { + assert!(!encr_value.eq("Hello World'")); // remove the quotes let mut hex_value = encr_value.chars(); hex_value.next(); hex_value.next_back(); - let s = hex_value.as_str().to_string(); + let s = hex_value.as_str().to_string(); let mv = MessageValue::Varchar(s); // build the row let row = vec![mv]; @@ -484,20 +507,18 @@ mod test { let statement = &cql.statements[0].statement; if let CassandraStatement::Select(select) = statement { - let result = protect.process_select(select, &cols, &mut rows ).await; - assert!( result.unwrap() ); - assert_eq!( &msg_value, &rows[0][0] ); + let result = protect.process_select(select, &cols, &mut rows).await; + assert!(result.unwrap()); + assert_eq!(&msg_value, &rows[0][0]); } - } else { - panic!( "Not a const value"); + panic!("Not a const value"); } } else { - panic!( "Not a InsertValues::Values object"); + panic!("Not a InsertValues::Values object"); } } else { panic!("not an INSERT"); } - } } diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index b4757a605..4556aeba9 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1386,7 +1386,7 @@ mod protect { "SELECT pk, cluster, col1, col2, col3 FROM test_protect_keyspace.test_table", ); if let ResultValue::Varchar(value) = &result[0][2] { - assert_eq!( "I am gonna get encrypted!!", value); + assert_eq!("I am gonna get encrypted!!", value); //assert!(value.starts_with("{\"Ciphertext")); } else { panic!("expected 3rd column to be ResultValue::Varchar in {result:?}"); @@ -1402,8 +1402,8 @@ mod protect { assert_eq!(result[0][0], ResultValue::Varchar("pk1".into())); assert_eq!(result[0][1], ResultValue::Varchar("cluster".into())); if let ResultValue::Varchar(value) = &result[0][2] { - let re = Regex::new( r"^[0-9a-fA-F]+$").unwrap(); - assert!( re.is_match( value )); + let re = Regex::new(r"^[0-9a-fA-F]+$").unwrap(); + assert!(re.is_match(value)); } else { panic!("expected 3rd column to be ResultValue::Varchar in {result:?}"); } From fda49073861f9e459c168eddeee6b87fbc8d63b7 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 25 Apr 2022 06:26:24 +0100 Subject: [PATCH 35/60] updated cql3 version --- Cargo.lock | 18 +++++++++--------- shotover-proxy/Cargo.toml | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e389421bf..fb56081e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,9 +54,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4361135be9122e0870de935d7c439aef945b9f9ddd4199a553b5270b49c82a27" +checksum = "08f9b8508dccb7687a1d6c4ce66b2b0ecef467c94667de27d8d7fe1f8d2a9cdc" [[package]] name = "arrayref" @@ -356,9 +356,9 @@ dependencies = [ [[package]] name = "clap" -version = "3.1.10" +version = "3.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3124f3f75ce09e22d1410043e1e24f2ecc44fad3afe4f08408f1f7663d68da2b" +checksum = "7c167e37342afc5f33fd87bbc870cedd020d2a6dffa05d45ccd9241fbdd146db" dependencies = [ "atty", "bitflags", @@ -446,8 +446,8 @@ dependencies = [ [[package]] name = "cql3_parser" -version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=create_FQName#e6d5b9d352d0212ebbeccd427422efca1c792ea6" +version = "0.1.0" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#f654a7f8795ab98d279be966e9fde1352f7d753b" dependencies = [ "bigdecimal", "bytes", @@ -2549,7 +2549,7 @@ dependencies = [ "cached", "cassandra-cpp", "cassandra-protocol", - "clap 3.1.10", + "clap 3.1.12", "cql3_parser", "crc16", "criterion", @@ -3054,9 +3054,9 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" dependencies = [ "env_logger", "lazy_static", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 713c19e5b..4d92d0943 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -43,7 +43,7 @@ anyhow = "1.0.31" # Parsers sqlparser = "0.16" -cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch = "create_FQName" } +cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch = "main" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" From 56950b73403092aa033f95bd37a0278f3c689342 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 25 Apr 2022 07:13:11 +0100 Subject: [PATCH 36/60] updated as per Connor review --- shotover-proxy/src/codec/cassandra.rs | 4 ++-- shotover-proxy/src/frame/cassandra.rs | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index 53366ce45..8bda97aa1 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -366,7 +366,7 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::parse_from_string("Select * from system.local where key = 'local'"), + query: CQL::parse_from_string("SELECT FROM system.local WHERE key = 'local'"), params: QueryParams::default(), }, }))]; @@ -387,7 +387,7 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::parse_from_string("insert into system.foo (bar) values ('bar2')"), + query: CQL::parse_from_string("INSERT INTO system.foo (bar) VALUES ('bar2')"), params: QueryParams::default(), }, }))]; diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 14d17b099..d1dc0d414 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,3 +1,5 @@ +use crate::message::QueryType::PubSubMessage; +use crate::message::{MessageValue, QueryType}; use anyhow::{anyhow, Result}; use bytes::Bytes; use cassandra_protocol::compression::Compression; @@ -37,9 +39,6 @@ use std::str::FromStr; use tracing::debug; use uuid::Uuid; -use crate::message::QueryType::PubSubMessage; -use crate::message::{MessageValue, QueryType}; - /// Extract the length of a BATCH statement (count of requests) from the body bytes fn get_batch_len(bytes: &[u8]) -> Result { let len = bytes.len(); From 1180075aa11e5ce252adea332fe1b96610abcbe1 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 25 Apr 2022 07:16:28 +0100 Subject: [PATCH 37/60] updated as per Connor review --- Cargo.lock | 2 +- shotover-proxy/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fb56081e3..3ccdfd9f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,7 +447,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.1.0" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=main#f654a7f8795ab98d279be966e9fde1352f7d753b" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git#f654a7f8795ab98d279be966e9fde1352f7d753b" dependencies = [ "bigdecimal", "bytes", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 4d92d0943..f8249c012 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -43,7 +43,7 @@ anyhow = "1.0.31" # Parsers sqlparser = "0.16" -cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch = "main" } +cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" From 65f58a822ab152160ac6f2ada336ecf6820e6300 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 25 Apr 2022 07:25:30 +0100 Subject: [PATCH 38/60] fixed typo --- shotover-proxy/src/codec/cassandra.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shotover-proxy/src/codec/cassandra.rs b/shotover-proxy/src/codec/cassandra.rs index 8bda97aa1..f13633722 100644 --- a/shotover-proxy/src/codec/cassandra.rs +++ b/shotover-proxy/src/codec/cassandra.rs @@ -366,7 +366,7 @@ mod cassandra_protocol_tests { tracing_id: None, warnings: vec![], operation: CassandraOperation::Query { - query: CQL::parse_from_string("SELECT FROM system.local WHERE key = 'local'"), + query: CQL::parse_from_string("SELECT * FROM system.local WHERE key = 'local'"), params: QueryParams::default(), }, }))]; From 31416e017cea23dd604c4ab44a362d489418a9db Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 26 Apr 2022 08:39:26 +0100 Subject: [PATCH 39/60] Changes as per Connor review --- shotover-proxy/src/frame/cassandra.rs | 163 +++++++------------ shotover-proxy/src/transforms/redis/cache.rs | 13 +- 2 files changed, 63 insertions(+), 113 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index d1dc0d414..7fa4cb846 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -260,21 +260,21 @@ impl CassandraFrame { // set to lowest type let mut result = QueryType::SchemaChange; for cql_statement in &cql.statements { - result = match cql_statement.get_query_type() { - QueryType::ReadWrite => QueryType::ReadWrite, - QueryType::Write => match result { - QueryType::ReadWrite | QueryType::Write => result, - QueryType::Read => QueryType::ReadWrite, - QueryType::SchemaChange | PubSubMessage => QueryType::Write, - }, - QueryType::Read => { + result = match (cql_statement.get_query_type(), &result) { + (QueryType::ReadWrite, _) => QueryType::ReadWrite, + (QueryType::Write, QueryType::ReadWrite | QueryType::Write) => result, + (QueryType::Write, QueryType::Read) => QueryType::ReadWrite, + (QueryType::Write, QueryType::SchemaChange | PubSubMessage) => { + QueryType::Write + } + (QueryType::Read, _) => { if result == QueryType::SchemaChange { QueryType::Read } else { result } } - QueryType::SchemaChange | PubSubMessage => result, + (QueryType::SchemaChange | PubSubMessage, _) => result, } } result @@ -285,31 +285,27 @@ impl CassandraFrame { /// returns a list of table names from the CassandraOperation pub fn get_table_names(&self) -> Vec<&FQName> { - let mut result = vec![]; match &self.operation { - CassandraOperation::Query { query: cql, .. } => { - for cql_statement in &cql.statements { - if let Some(name) = CQLStatement::get_table_name(&cql_statement.statement) { - result.push(name); - } - } - } - CassandraOperation::Batch(batch) => { - for q in &batch.queries { - if let BatchStatementType::Statement(cql) = &q.ty { - for cql_statement in &cql.statements { - if let Some(name) = - CQLStatement::get_table_name(&cql_statement.statement) - { - result.push(name); - } - } - } - } - } - _ => {} + CassandraOperation::Query { query: cql, .. } => cql + .statements + .iter() + .filter_map(|stmt| CQLStatement::get_table_name(&stmt.statement)) + .collect(), + CassandraOperation::Batch(batch) => batch + .queries + .iter() + .filter_map(|batch_stmt| match &batch_stmt.ty { + BatchStatementType::Statement(cql) => Some(cql), + _ => None, + }) + .flat_map(|cql| { + cql.statements + .iter() + .filter_map(|stmt| CQLStatement::get_table_name(&stmt.statement)) + }) + .collect(), + _ => vec![], } - result } pub fn encode(self) -> RawCassandraFrame { @@ -349,45 +345,28 @@ pub enum CassandraOperation { impl CassandraOperation { /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch - /// An Err is returned if the operation cannot contain queries or the queries failed to parse. /// /// TODO: This will return a custom iterator type when BATCH support is added pub fn queries(&mut self) -> Vec<&mut CassandraStatement> { - let mut result = vec![]; - /* - match self { - CassandraOperation::Query { query: cql, .. } => result.push( &mut *cql.statement), - // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol - _ => { } - } - */ if let CassandraOperation::Query { query: cql, .. } = self { - for cql_statement in &mut cql.statements { - result.push(&mut cql_statement.statement) - } + cql.statements + .iter_mut() + .map(|stmt| &mut stmt.statement) + .collect() + } else { + Vec::<&mut CassandraStatement>::new() } - result } /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch - /// An Err is returned if the operation cannot contain queries or the queries failed to parse. /// /// TODO: This will return a custom iterator type when BATCH support is added pub fn get_cql_statements(&mut self) -> Vec<&mut Box> { - let mut result = vec![]; - /* - match self { - CassandraOperation::Query { query: cql, .. } => result.push( &mut *cql.statement), - // TODO: Return CassandraOperation::Batch queries once we add BATCH parsing to cassandra-protocol - _ => { } - } - */ if let CassandraOperation::Query { query: cql, .. } = self { - for cql_statement in &mut cql.statements { - result.push(cql_statement) - } + cql.statements.iter_mut().collect() + } else { + vec![] } - result } fn to_direction(&self) -> Direction { @@ -789,11 +768,6 @@ pub struct CQL { } impl CQL { - /// the number of statements in the CQL - pub fn get_statement_count(&self) -> usize { - self.statements.len() - } - fn from_value_and_col_spec(value: &Value, col_spec: &ColSpec) -> Operand { match value { Value::Some(vec) => { @@ -807,6 +781,9 @@ impl CQL { Value::NotSet => Operand::Null, } } + + /// Get the value of the parameter named `name` if it exists. Otherwise the name itself is returned as + /// a parameter Operand. fn set_param_value_by_name( name: &str, query_params: &QueryParams, @@ -861,19 +838,11 @@ impl CQL { param_types: &[ColSpec], ) -> Operand { match operand { - Operand::Tuple(vec) => { - let mut vec2 = Vec::with_capacity(vec.len()); - vec.iter().for_each(|o| { - vec2.push(CQL::set_operand_if_param( - o, - param_idx, - query_params, - param_types, - )) - }); - - Operand::Tuple(vec2) - } + Operand::Tuple(vec) => Operand::Tuple( + vec.iter() + .map(|o| CQL::set_operand_if_param(o, param_idx, query_params, param_types)) + .collect(), + ), Operand::Param(param_name) => { if param_name.starts_with('?') { CQL::set_param_value_by_position(param_idx, query_params, param_types) @@ -882,19 +851,11 @@ impl CQL { CQL::set_param_value_by_name(name, query_params, param_types) } } - Operand::Collection(vec) => { - let mut vec2 = Vec::with_capacity(vec.len()); - vec.iter().for_each(|o| { - vec2.push(CQL::set_operand_if_param( - o, - param_idx, - query_params, - param_types, - )) - }); - - Operand::Collection(vec2) - } + Operand::Collection(vec) => Operand::Collection( + vec.iter() + .map(|o| CQL::set_operand_if_param(o, param_idx, query_params, param_types)) + .collect(), + ), _ => operand.clone(), } } @@ -927,18 +888,18 @@ impl CQL { pub fn parse_from_string(cql_query_str: &str) -> Self { debug!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); - - let mut vec = Vec::with_capacity(ast.statements.len()); - - for statement in &ast.statements { - vec.push(Box::new(CQLStatement { - has_error: statement.0, - statement: statement.1.clone(), - })); - } CQL { has_error: ast.has_error(), - statements: vec, + statements: ast + .statements + .iter() + .map(|stmt| { + Box::new(CQLStatement { + has_error: stmt.0, + statement: stmt.1.clone(), + }) + }) + .collect(), } } } @@ -1156,7 +1117,7 @@ mod test { INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz');"#; let cql = CQL::parse_from_string(query); - assert_eq!(3, cql.get_statement_count()); + assert_eq!(3, cql.statements.len()); assert!(!cql.has_error); } @@ -1167,7 +1128,7 @@ mod test { INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz');"#; let cql = CQL::parse_from_string(query); - assert_eq!(3, cql.get_statement_count()); + assert_eq!(3, cql.statements.len()); assert!(cql.has_error); assert!(!cql.statements[0].has_error); assert!(cql.statements[1].has_error); diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index c2fa0ccf1..58097dc21 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -469,16 +469,6 @@ fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) } else if select.where_clause.is_empty() { CacheableState::Skip("Can not cache if where clause is empty".into()) - /* } else if !select.columns.is_empty() { - if select.columns.len() == 1 && select.columns[0].eq(&SelectElement::Star) { - CacheableState::Read - } else { - CacheableState::Skip( - "Can not cache if columns other than '*' are selected".into(), - ) - } - - */ } else { CacheableState::Read(table_name.into()) } @@ -530,8 +520,7 @@ fn build_query_redis_key_from_value_map( query_values: &BTreeMap>, table_name: &str, ) -> Result { - let mut key: Vec = vec![]; - key.extend(table_name.as_bytes()); + let mut key = table_name.as_bytes().to_vec(); for c_name in &table_cache_schema.partition_key { let column_name = c_name.to_lowercase(); debug!("processing partition key segment: {}", column_name); From 0d55dba9cc6c8375e377e4a3aaabeeb3f22c9157 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 26 Apr 2022 09:51:11 +0100 Subject: [PATCH 40/60] fixed broken merge --- shotover-proxy/src/frame/cassandra.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index c811d37b7..9f7e0d427 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -30,7 +30,6 @@ use cql3_parser::update::AssignmentOperator; use itertools::Itertools; use nonzero_ext::nonzero; use sodiumoxide::hex; -use std::convert::TryInto; use std::fmt::{Display, Formatter}; use std::io::Cursor; use std::net::IpAddr; @@ -38,7 +37,6 @@ use std::num::NonZeroU32; use std::str::FromStr; use tracing::debug; use uuid::Uuid; -use crate::message::{MessageValue, QueryType}; /// Functions for operations on an unparsed Cassandra frame pub mod raw_frame { From 2faf9342ceb950fa953472e58c71d8246e7c3599 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 27 Apr 2022 08:51:15 +0100 Subject: [PATCH 41/60] Changes as per Connor and Lucas --- shotover-proxy/Cargo.toml | 1 + shotover-proxy/src/transforms/protect/mod.rs | 11 ++--------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index f8249c012..8bcad6c11 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] +## Uncomment the following line to force IDE's that resist configuration to run all tests #default = ["alpha-transforms"] # Include WIP alpha transforms in the public API alpha-transforms = [] diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index c2c1e3d5a..d56663ebe 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -40,21 +40,14 @@ impl Protect { fn encode(protected: &Protected) -> Result> { match protected { Protected::Plaintext(_) => Err(anyhow!("can not encode plain text")), - Protected::Ciphertext { .. } => match serde_json::to_vec(protected) { - Ok(data) => Ok(data), - Err(e) => Err(anyhow!("{:?}", e)), - }, + Protected::Ciphertext { .. } => Ok(serde_json::to_vec(protected)?), } } /// decodes a byte array into the Protected object. This is here to centeralize the serde for /// the Protected object. fn decode(data: &[u8]) -> Result { - let result = serde_json::from_slice(data); - match result { - Ok(decoded) => Ok(decoded), - Err(e) => Err(anyhow!("{:?}", e)), - } + Ok(serde_json::from_slice(data)?) } /// get the list of protected columns for the specified table name. Will return `None` if no columns From 4ed6ff798f749de05334f2fee8a89a553276f524 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 28 Apr 2022 07:18:31 +0100 Subject: [PATCH 42/60] changes as per Lucas --- shotover-proxy/tests/helpers/cassandra.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index f73746ee2..f664fb54d 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -12,10 +12,9 @@ pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { cluster.set_port(port).ok(); cluster.set_load_balance_round_robin(); let result = cluster.connect(); - if let Some(err) = &result.as_ref().err() { - panic!("{}", err); - } - result.unwrap() + // By default unwrap uses the Debug formatter `{:?}` which is extremely noisy for the error type returned by `connect()`. + // So we instead force the Display formatter `{}` on the error. + cluster.connect().map_err(|err| format!("{err}")).unwrap() } #[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Ord)] From 22e76c1c76a420d81f7c68fe4ffd7098b19f77a6 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 28 Apr 2022 07:33:13 +0100 Subject: [PATCH 43/60] fixed compile issue --- shotover-proxy/tests/helpers/cassandra.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/shotover-proxy/tests/helpers/cassandra.rs b/shotover-proxy/tests/helpers/cassandra.rs index f664fb54d..7e274cd92 100644 --- a/shotover-proxy/tests/helpers/cassandra.rs +++ b/shotover-proxy/tests/helpers/cassandra.rs @@ -11,7 +11,6 @@ pub fn cassandra_connection(contact_points: &str, port: u16) -> Session { cluster.set_credentials("cassandra", "cassandra").unwrap(); cluster.set_port(port).ok(); cluster.set_load_balance_round_robin(); - let result = cluster.connect(); // By default unwrap uses the Debug formatter `{:?}` which is extremely noisy for the error type returned by `connect()`. // So we instead force the Display formatter `{}` on the error. cluster.connect().map_err(|err| format!("{err}")).unwrap() From 95a25ef712a03b2ead73baca6ecb70a71c8abaa7 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 28 Apr 2022 07:37:22 +0100 Subject: [PATCH 44/60] changed to shotover repo --- shotover-proxy/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 8bcad6c11..b3fc33c29 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -44,7 +44,7 @@ anyhow = "1.0.31" # Parsers sqlparser = "0.16" -cql3_parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git" } +cql3_parser = { git = "https://github.com/shotover/rust_cql3_parser.git" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" From d424c6df2be9f4329763528c88cdb6216969436c Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 28 Apr 2022 07:55:16 +0100 Subject: [PATCH 45/60] updated Cargo.lock added logic documentation --- Cargo.lock | 2 +- shotover-proxy/src/frame/cassandra.rs | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 3ccdfd9f7..81a1f1221 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,7 +447,7 @@ dependencies = [ [[package]] name = "cql3_parser" version = "0.1.0" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git#f654a7f8795ab98d279be966e9fde1352f7d753b" +source = "git+https://github.com/shotover/rust_cql3_parser.git#f654a7f8795ab98d279be966e9fde1352f7d753b" dependencies = [ "bigdecimal", "bytes", diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 9f7e0d427..02e7a4f9a 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -799,6 +799,11 @@ impl CQL { param_types: &[ColSpec], ) -> Operand { if let Some(QueryValues::NamedValues(value_map)) = &query_params.values { + /* + this code block first uses the hash table to determine if there is a value for the name. + then, only if there is, does it do the longer iteration over the value map looking for the + name to extract the position which is then used to index the proper param_type. + */ if let Some(value) = value_map.get(name) { if let Some(idx) = value_map .iter() From ab61046ac061b7242874320cbbf01a3c4d5dcf28 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 29 Apr 2022 07:46:12 +0100 Subject: [PATCH 46/60] removed boxing in CQL structure --- shotover-proxy/src/frame/cassandra.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 02e7a4f9a..2cdbb0174 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -370,7 +370,7 @@ impl CassandraOperation { /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch /// /// TODO: This will return a custom iterator type when BATCH support is added - pub fn get_cql_statements(&mut self) -> Vec<&mut Box> { + pub fn get_cql_statements(&mut self) -> Vec<&mut CQLStatement> { if let CassandraOperation::Query { query: cql, .. } = self { cql.statements.iter_mut().collect() } else { @@ -772,7 +772,7 @@ impl Display for CQLStatement { #[derive(PartialEq, Debug, Clone)] pub struct CQL { - pub statements: Vec>, + pub statements: Vec, pub(crate) has_error: bool, } @@ -907,11 +907,9 @@ impl CQL { statements: ast .statements .iter() - .map(|stmt| { - Box::new(CQLStatement { - has_error: stmt.0, - statement: stmt.1.clone(), - }) + .map(|stmt| CQLStatement { + has_error: stmt.0, + statement: stmt.1.clone(), }) .collect(), } From 000579c263b78c7f2c5cd87fef913dbdbdd420a3 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 29 Apr 2022 08:13:21 +0100 Subject: [PATCH 47/60] Changed to cql3-parser 0.1.0 and commented out sqlparser --- Cargo.lock | 20 ++++++-------------- shotover-proxy/Cargo.toml | 4 ++-- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 81a1f1221..c4a4773e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,9 +445,10 @@ dependencies = [ ] [[package]] -name = "cql3_parser" +name = "cql3-parser" version = "0.1.0" -source = "git+https://github.com/shotover/rust_cql3_parser.git#f654a7f8795ab98d279be966e9fde1352f7d753b" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed1aad3b971b15c93eeb793834fd3ef93c79965c6a0244473811fc57b76e645" dependencies = [ "bigdecimal", "bytes", @@ -2550,7 +2551,7 @@ dependencies = [ "cassandra-cpp", "cassandra-protocol", "clap 3.1.12", - "cql3_parser", + "cql3-parser", "crc16", "criterion", "csv", @@ -2588,7 +2589,6 @@ dependencies = [ "serde_yaml", "serial_test", "sodiumoxide", - "sqlparser", "strum_macros", "test-helpers", "thiserror", @@ -2684,15 +2684,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "sqlparser" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9a527b68048eb95495a1508f6c8395c8defcff5ecdbe8ad4106d08a2ef2a3c" -dependencies = [ - "log", -] - [[package]] name = "static_assertions" version = "1.1.0" @@ -3095,7 +3086,8 @@ dependencies = [ [[package]] name = "tree-sitter-cql" version = "0.0.1" -source = "git+https://github.com/Claude-at-Instaclustr/tree-sitter-cql?branch=main#ead1375c83892111d31540e6d8d0dc563f187939" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fdcd73bf0389e82b70592060b7e25195d8d1728d7a0b76b549e92f7f5d124b4" dependencies = [ "cc", "regex", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index b3fc33c29..11c4d9192 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -43,8 +43,8 @@ thiserror = "1.0" anyhow = "1.0.31" # Parsers -sqlparser = "0.16" -cql3_parser = { git = "https://github.com/shotover/rust_cql3_parser.git" } +#sqlparser = "0.16" +cql3-parser = "0.1.0" serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" From 684299848df235aab3dde061a0d5ad91da557ffb Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 3 May 2022 08:23:00 +0100 Subject: [PATCH 48/60] Updated as per code review --- .../src/transforms/cassandra/peers_rewrite.rs | 78 +++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 60770e3a3..7d2cc1db6 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -1,4 +1,3 @@ -use crate::frame::cassandra::CQLStatement; use crate::frame::{CassandraOperation, CassandraResult, Frame}; use crate::message::{IntSize, Message, MessageValue}; use crate::{ @@ -8,6 +7,7 @@ use crate::{ use anyhow::Result; use async_trait::async_trait; use cql3_parser::cassandra_statement::CassandraStatement; +use cql3_parser::common::FQName; use cql3_parser::select::SelectElement; use serde::Deserialize; use std::collections::HashMap; @@ -21,6 +21,7 @@ impl CassandraPeersRewriteConfig { pub async fn get_transform(&self) -> Result { Ok(Transforms::CassandraPeersRewrite(CassandraPeersRewrite { port: self.port, + peer_table: FQName::new("system", "peers_v2"), })) } } @@ -28,6 +29,7 @@ impl CassandraPeersRewriteConfig { #[derive(Clone)] pub struct CassandraPeersRewrite { port: u32, + peer_table: FQName, } #[async_trait] @@ -35,27 +37,19 @@ impl Transform for CassandraPeersRewrite { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { // Find the indices of queries to system.peers & system.peers_v2 // we need to know which columns in which CQL queries in which messages have system peers - let mut column_names: HashMap> = HashMap::new(); - - message_wrapper + let column_names: HashMap> = message_wrapper .messages .iter_mut() .enumerate() .filter_map(|(i, m)| { - let sys_peers = extract_native_port_column(m); + let sys_peers = extract_native_port_column(&self.peer_table, m); if sys_peers.is_empty() { None } else { Some((i, sys_peers)) } }) - .for_each(|(k, mut v)| { - if let Some(x) = column_names.get_mut(&k) { - x.append(&mut v); - } else { - column_names.insert(k, v); - } - }); + .collect(); let mut response = message_wrapper.call_next_transform().await?; @@ -68,29 +62,27 @@ impl Transform for CassandraPeersRewrite { } /// determine if the message contains a SELECT from `system.peers_v2` that includes the `native_port` column -/// return a list of (statement index, column index) pairs -fn extract_native_port_column(message: &mut Message) -> Vec { +/// return a list of column names (or their alias) for each `native_port`. +fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec { let mut result: Vec = vec![]; if let Some(Frame::Cassandra(cassandra)) = message.frame() { if let CassandraOperation::Query { query, .. } = &cassandra.operation { for cql_statement in &query.statements { let statement = &cql_statement.statement; if let CassandraStatement::Select(select) = &statement { - if let Some(table_name) = CQLStatement::get_table_name(statement) { - if table_name.to_string().eq("system.peers_v2") { - select - .columns - .iter() - .for_each(|select_element| match select_element { - SelectElement::Column(col_name) => { - if col_name.name.eq("native_port") { - result.push(col_name.alias_or_name()); - } + if peer_table.eq(&select.table_name) { + select + .columns + .iter() + .for_each(|select_element| match select_element { + SelectElement::Column(col_name) => { + if col_name.name == "native_port" { + result.push(col_name.alias_or_name()); } - SelectElement::Star => result.push("native_port".to_string()), - _ => {} - }); - } + } + SelectElement::Star => result.push("native_port".to_string()), + _ => {} + }); } } } @@ -204,25 +196,33 @@ mod test { #[test] fn test_is_system_peers_v2() { - let v = - extract_native_port_column(&mut create_query_message("SELECT * FROM system.peers_v2;")); + let peer_table = FQName::new("system", "peers_v2"); + let v = extract_native_port_column( + &peer_table, + &mut create_query_message("SELECT * FROM system.peers_v2;"), + ); assert_eq!(1, v.len()); assert_eq!("native_port", v[0]); - let v = extract_native_port_column(&mut create_query_message( - "SELECT * FROM not_system.peers_v2;", - )); + let v = extract_native_port_column( + &peer_table, + &mut create_query_message("SELECT * FROM not_system.peers_v2;"), + ); assert!(v.is_empty()); - let v = extract_native_port_column(&mut create_query_message( - "SELECT native_port as foo from system.peers_v2", - )); + let v = extract_native_port_column( + &peer_table, + &mut create_query_message("SELECT native_port as foo from system.peers_v2"), + ); assert_eq!(1, v.len()); assert_eq!("foo", v[0]); - let v = extract_native_port_column(&mut create_query_message( - "SELECT native_port as foo, native_port from system.peers_v2", - )); + let v = extract_native_port_column( + &peer_table, + &mut create_query_message( + "SELECT native_port as foo, native_port from system.peers_v2", + ), + ); assert_eq!(2, v.len()); assert_eq!("foo", v[0]); assert_eq!("native_port", v[1]); From 83b05e6cd626f289e4dab9f0a0c6353af57bd356 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 4 May 2022 08:59:43 +0100 Subject: [PATCH 49/60] Changes as per Lucas --- shotover-proxy/Cargo.toml | 1 - shotover-proxy/src/transforms/protect/mod.rs | 11 ++--------- shotover-proxy/src/transforms/redis/cache.rs | 4 ++-- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 11c4d9192..13a3df065 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -43,7 +43,6 @@ thiserror = "1.0" anyhow = "1.0.31" # Parsers -#sqlparser = "0.16" cql3-parser = "0.1.0" serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index d56663ebe..bec7f12cc 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -299,12 +299,7 @@ impl Protected { let sym_key = key_management .cached_get_key(key_id.to_string(), Some(enc_dek), Some(kek_id)) .await?; - let result = decrypt(cipher, nonce, &sym_key.plaintext); - if result.is_err() { - Err(anyhow!("{}", result.err().unwrap())) - } else { - Ok(result.unwrap()) - } + decrypt(cipher, nonce, &sym_key.plaintext) } } } @@ -439,9 +434,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] //#[test] async fn round_trip_test() { - if sodiumoxide::init().is_err() { - panic!("could not init sodiumoxide"); - } + sodiumoxide::init().expect("could not init sodiumoxide"); // verify low level round trip works. let kek = sodiumoxide::crypto::secretbox::xsalsa20poly1305::gen_key(); diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 58097dc21..434af6f01 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -438,8 +438,8 @@ impl SimpleRedisCache { "clientdetailstodo".to_string(), ) .await; - if result.is_err() { - warn!("Cache error: {}", result.err().unwrap()); + if let Err(err) = result { + warn!("Cache error: {}", err); } } } From 098caecb1eb694eb2c3437f80e4ff432ee4fd8ab Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 4 May 2022 11:41:48 +0100 Subject: [PATCH 50/60] changes as per Lucas --- shotover-proxy/src/frame/cassandra.rs | 262 ++---------------- .../src/transforms/cassandra/peers_rewrite.rs | 21 +- shotover-proxy/src/transforms/redis/cache.rs | 2 +- .../cassandra_int_tests/basic_driver_tests.rs | 5 +- 4 files changed, 39 insertions(+), 251 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 2cdbb0174..c589a5b91 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -1,4 +1,3 @@ -use crate::message::QueryType::PubSubMessage; use crate::message::{MessageValue, QueryType}; use anyhow::{anyhow, Result}; use bytes::Bytes; @@ -11,7 +10,7 @@ use cassandra_protocol::frame::frame_query::BodyReqQuery; use cassandra_protocol::frame::frame_request::RequestBody; use cassandra_protocol::frame::frame_response::ResponseBody; use cassandra_protocol::frame::frame_result::{ - BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ColSpec, ResResultBody, + BodyResResultPrepared, BodyResResultRows, BodyResResultSetKeyspace, ResResultBody, RowsMetadata, RowsMetadataFlags, }; use cassandra_protocol::frame::{ @@ -20,7 +19,6 @@ use cassandra_protocol::frame::{ use cassandra_protocol::query::{QueryParams, QueryValues}; use cassandra_protocol::types::blob::Blob; use cassandra_protocol::types::cassandra_type::CassandraType; -use cassandra_protocol::types::value::Value; use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; @@ -256,34 +254,40 @@ impl CassandraFrame { } /// returns the query type for the current statement. + /// Query type is calculated by scanning the query types of the enclosed statements with the + /// highest valued result being returned. + /// + /// Statements, in descending order, are: + /// + /// * ReadWrite + /// * Write + /// * Read + /// * SchemaChange + /// * PubSubMessage pub fn get_query_type(&self) -> QueryType { - /* - Read, - Write, - ReadWrite, - SchemaChange, - PubSubMessage, - */ match &self.operation { CassandraOperation::Query { query: cql, .. } => { // set to lowest type - let mut result = QueryType::SchemaChange; + let mut result = QueryType::PubSubMessage; for cql_statement in &cql.statements { result = match (cql_statement.get_query_type(), &result) { (QueryType::ReadWrite, _) => QueryType::ReadWrite, - (QueryType::Write, QueryType::ReadWrite | QueryType::Write) => result, - (QueryType::Write, QueryType::Read) => QueryType::ReadWrite, - (QueryType::Write, QueryType::SchemaChange | PubSubMessage) => { + (QueryType::Write, QueryType::Read | QueryType::ReadWrite) => { + QueryType::ReadWrite + } + (QueryType::Write, QueryType::SchemaChange | QueryType::PubSubMessage) => { QueryType::Write } - (QueryType::Read, _) => { - if result == QueryType::SchemaChange { - QueryType::Read - } else { - result - } + (QueryType::Read, QueryType::ReadWrite | QueryType::Write) => { + QueryType::ReadWrite + } + (QueryType::Read, QueryType::SchemaChange | QueryType::PubSubMessage) => { + QueryType::Read + } + (QueryType::SchemaChange, QueryType::PubSubMessage) => { + QueryType::SchemaChange } - (QueryType::SchemaChange | PubSubMessage, _) => result, + _ => result, } } result @@ -518,10 +522,6 @@ impl CQLStatement { } } - pub fn is_apply_batch(&self) -> bool { - matches!(&self.statement, CassandraStatement::ApplyBatch) - } - /// returns the query type for the current statement. pub fn get_query_type(&self) -> QueryType { /* @@ -591,99 +591,6 @@ impl CQLStatement { } } - /// replaces the Operand::Param objects with Operand::Const objects where the parameters are defined in the - /// QueryParameters. - /// This method makes a copy of the CassandraStatement - pub fn set_param_values( - &self, - params: &QueryParams, - param_types: &[ColSpec], - ) -> CassandraStatement { - let mut param_idx: usize = 0; - let mut statement = self.statement.clone(); - match &mut statement { - CassandraStatement::Delete(delete) => { - CQL::set_relation_elements_values( - &mut param_idx, - params, - param_types, - &mut delete.where_clause, - ); - CQL::set_relation_elements_values( - &mut param_idx, - params, - param_types, - &mut delete.if_clause, - ); - } - CassandraStatement::Insert(insert) => { - if let InsertValues::Values(operands) = &mut insert.values { - for operand in operands { - *operand = - CQL::set_operand_if_param(operand, &mut param_idx, params, param_types) - } - } - } - CassandraStatement::Select(select) => { - CQL::set_relation_elements_values( - &mut param_idx, - params, - param_types, - &mut select.where_clause, - ); - } - CassandraStatement::Update(update) => { - for assignment_idx in 0..update.assignments.len() { - let mut assignment_element = &mut update.assignments[assignment_idx]; - assignment_element.value = CQL::set_operand_if_param( - &assignment_element.value, - &mut param_idx, - params, - param_types, - ); - if let Some(assignment_operator) = &assignment_element.operator { - match assignment_operator { - AssignmentOperator::Plus(operand) => { - assignment_element.operator = Option::from( - AssignmentOperator::Plus(CQL::set_operand_if_param( - operand, - &mut param_idx, - params, - param_types, - )), - ); - } - AssignmentOperator::Minus(operand) => { - assignment_element.operator = Option::from( - AssignmentOperator::Minus(CQL::set_operand_if_param( - operand, - &mut param_idx, - params, - param_types, - )), - ); - } - } - } - } - CQL::set_relation_elements_values( - &mut param_idx, - params, - param_types, - &mut update.where_clause, - ); - CQL::set_relation_elements_values( - &mut param_idx, - params, - param_types, - &mut update.if_clause, - ); - } - _ => {} - } - statement - } - fn has_params_in_operand(operand: &Operand) -> bool { match operand { Operand::Tuple(vec) | Operand::Collection(vec) => { @@ -773,121 +680,11 @@ impl Display for CQLStatement { #[derive(PartialEq, Debug, Clone)] pub struct CQL { pub statements: Vec, - pub(crate) has_error: bool, } impl CQL { - fn from_value_and_col_spec(value: &Value, col_spec: &ColSpec) -> Operand { - match value { - Value::Some(vec) => { - let cbytes = CBytes::new(vec.clone()); - let message_value = - MessageValue::build_value_from_cstar_col_type(col_spec, &cbytes); - let pmsg_value = &message_value; - pmsg_value.into() - } - Value::Null => Operand::Null, - Value::NotSet => Operand::Null, - } - } - - /// Get the value of the parameter named `name` if it exists. Otherwise the name itself is returned as - /// a parameter Operand. - fn set_param_value_by_name( - name: &str, - query_params: &QueryParams, - param_types: &[ColSpec], - ) -> Operand { - if let Some(QueryValues::NamedValues(value_map)) = &query_params.values { - /* - this code block first uses the hash table to determine if there is a value for the name. - then, only if there is, does it do the longer iteration over the value map looking for the - name to extract the position which is then used to index the proper param_type. - */ - if let Some(value) = value_map.get(name) { - if let Some(idx) = value_map - .iter() - .enumerate() - .filter_map( - |(idx, (key, _value))| { - if key.eq(name) { - Some(idx) - } else { - None - } - }, - ) - .next() - { - return CQL::from_value_and_col_spec(value, ¶m_types[idx]); - } - } - } - Operand::Param(format!(":{}", name)) - } - - fn set_param_value_by_position( - param_idx: &mut usize, - query_params: &QueryParams, - param_types: &[ColSpec], - ) -> Operand { - if let Some(QueryValues::SimpleValues(values)) = &query_params.values { - if let Some(value) = values.get(*param_idx) { - *param_idx += 1; - CQL::from_value_and_col_spec(value, ¶m_types[*param_idx]) - } else { - *param_idx += 1; - Operand::Param("?".into()) - } - } else { - *param_idx += 1; - Operand::Param("?".into()) - } - } - - fn set_operand_if_param( - operand: &Operand, - param_idx: &mut usize, - query_params: &QueryParams, - param_types: &[ColSpec], - ) -> Operand { - match operand { - Operand::Tuple(vec) => Operand::Tuple( - vec.iter() - .map(|o| CQL::set_operand_if_param(o, param_idx, query_params, param_types)) - .collect(), - ), - Operand::Param(param_name) => { - if param_name.starts_with('?') { - CQL::set_param_value_by_position(param_idx, query_params, param_types) - } else { - let name = param_name.split_at(0).1; - CQL::set_param_value_by_name(name, query_params, param_types) - } - } - Operand::Collection(vec) => Operand::Collection( - vec.iter() - .map(|o| CQL::set_operand_if_param(o, param_idx, query_params, param_types)) - .collect(), - ), - _ => operand.clone(), - } - } - - fn set_relation_elements_values( - param_idx: &mut usize, - query_params: &QueryParams, - param_types: &[ColSpec], - where_clause: &mut [RelationElement], - ) { - for relation_element in where_clause { - relation_element.value = CQL::set_operand_if_param( - &relation_element.value, - param_idx, - query_params, - param_types, - ); - } + pub fn has_error(&self) -> bool { + self.statements.iter().any(|s| s.has_error) } pub fn to_query_string(&self) -> String { @@ -903,7 +700,6 @@ impl CQL { debug!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); CQL { - has_error: ast.has_error(), statements: ast .statements .iter() @@ -1130,7 +926,7 @@ mod test { let cql = CQL::parse_from_string(query); assert_eq!(3, cql.statements.len()); - assert!(!cql.has_error); + assert!(!cql.has_error()); } #[test] @@ -1141,7 +937,7 @@ mod test { let cql = CQL::parse_from_string(query); assert_eq!(3, cql.statements.len()); - assert!(cql.has_error); + assert!(cql.has_error()); assert!(!cql.statements[0].has_error); assert!(cql.statements[1].has_error); assert!(!cql.statements[2].has_error); diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 7d2cc1db6..02732dd8d 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -70,11 +70,9 @@ fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec for cql_statement in &query.statements { let statement = &cql_statement.statement; if let CassandraStatement::Select(select) = &statement { - if peer_table.eq(&select.table_name) { - select - .columns - .iter() - .for_each(|select_element| match select_element { + if peer_table == &select.table_name { + for select_element in &select.columns { + match select_element { SelectElement::Column(col_name) => { if col_name.name == "native_port" { result.push(col_name.alias_or_name()); @@ -82,7 +80,8 @@ fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec } SelectElement::Star => result.push("native_port".to_string()), _ => {} - }); + } + } } } } @@ -201,8 +200,7 @@ mod test { &peer_table, &mut create_query_message("SELECT * FROM system.peers_v2;"), ); - assert_eq!(1, v.len()); - assert_eq!("native_port", v[0]); + assert_eq!(vec!("native_port".to_string()), v); let v = extract_native_port_column( &peer_table, @@ -214,8 +212,7 @@ mod test { &peer_table, &mut create_query_message("SELECT native_port as foo from system.peers_v2"), ); - assert_eq!(1, v.len()); - assert_eq!("foo", v[0]); + assert_eq!(vec!("foo".to_string()), v); let v = extract_native_port_column( &peer_table, @@ -223,9 +220,7 @@ mod test { "SELECT native_port as foo, native_port from system.peers_v2", ), ); - assert_eq!(2, v.len()); - assert_eq!("foo", v[0]); - assert_eq!("native_port", v[1]); + assert_eq!(vec!["foo".to_string(), "native_port".to_string()], v); } #[test] diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 434af6f01..bcd612243 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -791,7 +791,7 @@ mod test { fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); - assert!(!cql.has_error); + assert!(!cql.has_error()); cql.statements[0].statement.clone() } diff --git a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs index 4556aeba9..1e3933294 100644 --- a/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/cassandra_int_tests/basic_driver_tests.rs @@ -1598,10 +1598,7 @@ fn test_source_tls_and_single_tls() { fn test_cassandra_redis_cache() { let recorder = DebuggingRecorder::new(); let snapshotter = recorder.snapshotter(); - let result = recorder.install(); - if result.is_err() { - panic!("{:?}", result.err()); - } + recorder.install().unwrap(); let _compose = DockerCompose::new("example-configs/cassandra-redis-cache/docker-compose.yml"); let shotover_manager = ShotoverManager::from_topology_file_without_observability( From 7a1614c0766ab99194f0a8a2fa1ad47d09877989 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 5 May 2022 06:35:13 +0100 Subject: [PATCH 51/60] changes requested by Lucas --- shotover-proxy/src/frame/cassandra.rs | 30 +++++++-------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index c589a5b91..968e8d4f8 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -297,28 +297,12 @@ impl CassandraFrame { } /// returns a list of table names from the CassandraOperation - pub fn get_table_names(&self) -> Vec<&FQName> { - match &self.operation { - CassandraOperation::Query { query: cql, .. } => cql - .statements - .iter() - .filter_map(|stmt| CQLStatement::get_table_name(&stmt.statement)) - .collect(), - CassandraOperation::Batch(batch) => batch - .queries - .iter() - .filter_map(|batch_stmt| match &batch_stmt.ty { - BatchStatementType::Statement(cql) => Some(cql), - _ => None, - }) - .flat_map(|cql| { - cql.statements - .iter() - .filter_map(|stmt| CQLStatement::get_table_name(&stmt.statement)) - }) - .collect(), - _ => vec![], - } + pub fn get_table_names(&mut self) -> Vec<&FQName> { + self.operation + .queries() + .into_iter() + .filter_map(|stmt| CQLStatement::get_table_name(stmt)) + .collect() } pub fn encode(self) -> RawCassandraFrame { @@ -687,7 +671,7 @@ impl CQL { self.statements.iter().any(|s| s.has_error) } - pub fn to_query_string(&self) -> String { + fn to_query_string(&self) -> String { self.statements .iter() .map(|c| c.statement.to_string()) From e09cf520014e07ed64ac65984a5733bd79445d33 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 9 May 2022 08:37:42 +0100 Subject: [PATCH 52/60] fixed string issue --- shotover-proxy/src/transforms/cassandra/peers_rewrite.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 02732dd8d..518cca83c 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -75,7 +75,7 @@ fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec match select_element { SelectElement::Column(col_name) => { if col_name.name == "native_port" { - result.push(col_name.alias_or_name()); + result.push(col_name.alias_or_name().to_string()); } } SelectElement::Star => result.push("native_port".to_string()), From 8b69aa947c562824227941d821dd9047c1032de8 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 9 May 2022 11:17:22 +0100 Subject: [PATCH 53/60] removed default feature from cargo --- shotover-proxy/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 345bbe184..a352e29a4 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -10,7 +10,7 @@ license = "Apache-2.0" [features] ## Uncomment the following line to force IDE's that resist configuration to run all tests -default = ["alpha-transforms"] +#default = ["alpha-transforms"] # Include WIP alpha transforms in the public API alpha-transforms = [] From ce41cebd73364a8bfc48299088b5a2b47c97ddd1 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Tue, 10 May 2022 08:08:06 +0100 Subject: [PATCH 54/60] converted parse error to Unknown() --- shotover-proxy/src/frame/cassandra.rs | 45 ++++++++++++++------ shotover-proxy/src/transforms/redis/cache.rs | 4 -- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 35ffb70ed..bef27fc79 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -507,7 +507,6 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQLStatement { pub statement: CassandraStatement, - pub has_error: bool, } impl CQLStatement { @@ -681,10 +680,6 @@ pub struct CQL { } impl CQL { - pub fn has_error(&self) -> bool { - self.statements.iter().any(|s| s.has_error) - } - fn to_query_string(&self) -> String { self.statements .iter() @@ -697,13 +692,21 @@ impl CQL { pub fn parse_from_string(cql_query_str: &str) -> Self { debug!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); + CQL { statements: ast .statements .iter() - .map(|stmt| CQLStatement { - has_error: stmt.0, - statement: stmt.1.clone(), + .map(|(_, stmt)| match (ast.has_error(), stmt) { + (true, CassandraStatement::Unknown(_)) => CQLStatement { + statement: stmt.clone(), + }, + (true, _) => CQLStatement { + statement: CassandraStatement::Unknown(stmt.to_string()), + }, + (false, _) => CQLStatement { + statement: stmt.clone(), + }, }) .collect(), } @@ -901,6 +904,7 @@ pub struct CassandraBatch { #[cfg(test)] mod test { use crate::frame::CQL; + use cql3_parser::cassandra_statement::CassandraStatement; #[test] fn cql_round_trip_test() { @@ -924,7 +928,16 @@ mod test { let cql = CQL::parse_from_string(query); assert_eq!(3, cql.statements.len()); - assert!(!cql.has_error()); + for stmt in cql.statements { + if let CassandraStatement::Insert(_x) = stmt.statement { + // do nothing + } else { + panic!( + "{:?} should have been CassandraStatement::Insert", + stmt.statement + ); + } + } } #[test] @@ -935,9 +948,15 @@ mod test { let cql = CQL::parse_from_string(query); assert_eq!(3, cql.statements.len()); - assert!(cql.has_error()); - assert!(!cql.statements[0].has_error); - assert!(cql.statements[1].has_error); - assert!(!cql.statements[2].has_error); + for stmt in cql.statements { + if let CassandraStatement::Unknown(_x) = stmt.statement { + // do nothing + } else { + panic!( + "{:?} should have been CassandraStatement::Unknown", + stmt.statement + ); + } + } } } diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index bcd612243..7c92d5bee 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -455,9 +455,6 @@ impl SimpleRedisCache { /// * fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { // check issues common to all cql_statements - if cql_statement.has_error { - return CacheableState::Skip("CQL statement has error".into()); - } if let Some(table_name) = CQLStatement::get_table_name(&cql_statement.statement) { let has_params = CQLStatement::has_params(&cql_statement.statement); @@ -791,7 +788,6 @@ mod test { fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); - assert!(!cql.has_error()); cql.statements[0].statement.clone() } From 9153aeef985a62428a03454a8b86f3338d1dd9fe Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 11 May 2022 13:57:34 +0100 Subject: [PATCH 55/60] first fix of parse_from_string --- shotover-proxy/src/frame/cassandra.rs | 120 ++++++++++++-------------- 1 file changed, 54 insertions(+), 66 deletions(-) diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index bef27fc79..287554e93 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -28,7 +28,6 @@ use cql3_parser::update::AssignmentOperator; use itertools::Itertools; use nonzero_ext::nonzero; use sodiumoxide::hex; -use std::fmt::{Display, Formatter}; use std::io::Cursor; use std::net::IpAddr; use std::num::NonZeroU32; @@ -275,7 +274,7 @@ impl CassandraFrame { // set to lowest type let mut result = QueryType::PubSubMessage; for cql_statement in &cql.statements { - result = match (cql_statement.get_query_type(), &result) { + result = match (CQLStatement::get_query_type(cql_statement), &result) { (QueryType::ReadWrite, _) => QueryType::ReadWrite, (QueryType::Write, QueryType::Read | QueryType::ReadWrite) => { QueryType::ReadWrite @@ -362,24 +361,12 @@ impl CassandraOperation { if let CassandraOperation::Query { query: cql, .. } = self { cql.statements .iter_mut() - .map(|stmt| &mut stmt.statement) .collect() } else { Vec::<&mut CassandraStatement>::new() } } - /// Return all queries contained within CassandaOperation::Query and CassandraOperation::Batch - /// - /// TODO: This will return a custom iterator type when BATCH support is added - pub fn get_cql_statements(&mut self) -> Vec<&mut CQLStatement> { - if let CassandraOperation::Query { query: cql, .. } = self { - cql.statements.iter_mut().collect() - } else { - vec![] - } - } - fn to_direction(&self) -> Direction { match self { CassandraOperation::Query { .. } => Direction::Request, @@ -506,12 +493,11 @@ impl CassandraOperation { #[derive(PartialEq, Debug, Clone)] pub struct CQLStatement { - pub statement: CassandraStatement, } impl CQLStatement { - pub fn is_begin_batch(&self) -> bool { - match &self.statement { + pub fn is_begin_batch( statement : &CassandraStatement ) -> bool { + match statement { CassandraStatement::Delete(delete) => delete.begin_batch.is_some(), CassandraStatement::Insert(insert) => insert.begin_batch.is_some(), CassandraStatement::Update(update) => update.begin_batch.is_some(), @@ -520,15 +506,15 @@ impl CQLStatement { } /// returns the query type for the current statement. - pub fn get_query_type(&self) -> QueryType { - /* - Read, - Write, - ReadWrite, - SchemaChange, - PubSubMessage, - */ - match &self.statement { + pub fn get_query_type(statement : &CassandraStatement) -> QueryType { + /* the query types in descending order, are: + ReadWrite + Write + Read + SchemaChange + PubSubMessage + */ + match statement { CassandraStatement::AlterKeyspace(_) => QueryType::SchemaChange, CassandraStatement::AlterMaterializedView(_) => QueryType::SchemaChange, CassandraStatement::AlterRole(_) => QueryType::SchemaChange, @@ -668,22 +654,16 @@ impl CQLStatement { } } -impl Display for CQLStatement { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - self.statement.fmt(f) - } -} - #[derive(PartialEq, Debug, Clone)] pub struct CQL { - pub statements: Vec, + pub statements: Vec, } impl CQL { fn to_query_string(&self) -> String { self.statements .iter() - .map(|c| c.statement.to_string()) + .map(|c| c.to_string()) .join("; ") } @@ -693,23 +673,24 @@ impl CQL { debug!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); - CQL { - statements: ast - .statements - .iter() - .map(|(_, stmt)| match (ast.has_error(), stmt) { - (true, CassandraStatement::Unknown(_)) => CQLStatement { - statement: stmt.clone(), - }, - (true, _) => CQLStatement { - statement: CassandraStatement::Unknown(stmt.to_string()), - }, - (false, _) => CQLStatement { - statement: stmt.clone(), + let error = ast.has_error(); + let statements = ast + .statements + .into_iter() + .map(|(_, statement, start, stop)| { + match (error, statement) { + (true, statement @ CassandraStatement::Unknown(_)) => statement, + (true, _) => { + match String::from_utf8(Vec::from(&cql_query_str[start..stop])) { + Ok(str) => CassandraStatement::Unknown(str), + Err(_) => CassandraStatement::Unknown(cql_query_str.to_string()) + } }, - }) - .collect(), - } + (false, statement) => statement, + } + }) + .collect(); + CQL { statements } } } @@ -929,12 +910,12 @@ mod test { let cql = CQL::parse_from_string(query); assert_eq!(3, cql.statements.len()); for stmt in cql.statements { - if let CassandraStatement::Insert(_x) = stmt.statement { + if let CassandraStatement::Insert(_x) = stmt { // do nothing } else { panic!( "{:?} should have been CassandraStatement::Insert", - stmt.statement + stmt ); } } @@ -942,21 +923,28 @@ mod test { #[test] fn cql_bad_statement_test() { - let query = r#"INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (1, 11, 'foo'); - INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar'); - INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz');"#; - - let cql = CQL::parse_from_string(query); - assert_eq!(3, cql.statements.len()); - for stmt in cql.statements { - if let CassandraStatement::Unknown(_x) = stmt.statement { - // do nothing - } else { - panic!( - "{:?} should have been CassandraStatement::Unknown", - stmt.statement - ); + let query = [ + "INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');", + r#"BEGIN BATCH INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar'); + EXECUTE BATCH"#, + r#"INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); + INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');"#]; + let expected = [ + "INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');", + r#"BEGIN BATCH INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar'); + EXECUTE BATCH"#, + "INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');"]; + for idx in 0..query.len() { + let cql = CQL::parse_from_string(query[idx]); + let result = cql.to_query_string(); + for stmt in cql.statements { + match stmt { + CassandraStatement::Unknown(_) => {}, + _ => panic!( "Should be Unknown type"), + } } + assert_eq!(expected[idx], result, "failed at test {}", idx); } } } From 1cc13add46b779f4da17380b4c87db4f7d201bfe Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Wed, 11 May 2022 14:09:20 +0100 Subject: [PATCH 56/60] Made CQL.statments Vec Cleaned up parsing of statements with errors. --- Cargo.lock | 3 +- shotover-proxy/Cargo.toml | 2 +- shotover-proxy/src/frame/cassandra.rs | 41 ++++++------------- .../src/transforms/cassandra/peers_rewrite.rs | 5 +-- shotover-proxy/src/transforms/protect/mod.rs | 12 +++--- shotover-proxy/src/transforms/redis/cache.rs | 36 ++++++++-------- 6 files changed, 39 insertions(+), 60 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc53a48d8..526cf2405 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,7 +447,7 @@ dependencies = [ [[package]] name = "cql3-parser" version = "0.1.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=create_FQName#932855cda5ba64d85068396b67be51afcca0120e" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=verify_error#e87dcc207247e735b6942c20d52cbc1d27bcb28a" dependencies = [ "bigdecimal", "bytes", @@ -455,7 +455,6 @@ dependencies = [ "itertools", "num", "regex", - "serde", "tree-sitter", "tree-sitter-cql", "uuid 1.0.0", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index a352e29a4..3b7afbee9 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -43,7 +43,7 @@ thiserror = "1.0" anyhow = "1.0.31" # Parsers -cql3-parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch="create_FQName" } +cql3-parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch="verify_error" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 287554e93..20344dff6 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -673,23 +673,13 @@ impl CQL { debug!("parse_from_string: {}", cql_query_str); let ast = CassandraAST::new(cql_query_str); - let error = ast.has_error(); - let statements = ast - .statements - .into_iter() - .map(|(_, statement, start, stop)| { - match (error, statement) { - (true, statement @ CassandraStatement::Unknown(_)) => statement, - (true, _) => { - match String::from_utf8(Vec::from(&cql_query_str[start..stop])) { - Ok(str) => CassandraStatement::Unknown(str), - Err(_) => CassandraStatement::Unknown(cql_query_str.to_string()) - } - }, - (false, statement) => statement, - } - }) - .collect(); + let statements = if ast.has_error() { + vec![ CassandraStatement::Unknown( cql_query_str.to_string() )] + } else { + ast.statements.into_iter() + .map( |(_err, statement, _start, _end)| statement ) + .collect() + }; CQL { statements } } } @@ -930,21 +920,16 @@ mod test { EXECUTE BATCH"#, r#"INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');"#]; - let expected = [ - "INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');", - r#"BEGIN BATCH INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar'); - EXECUTE BATCH"#, - "INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');"]; for idx in 0..query.len() { let cql = CQL::parse_from_string(query[idx]); let result = cql.to_query_string(); - for stmt in cql.statements { - match stmt { - CassandraStatement::Unknown(_) => {}, - _ => panic!( "Should be Unknown type"), - } + assert_eq!( 1, cql.statements.len() ); + if let CassandraStatement::Unknown( txt ) = &cql.statements[0] { + assert_eq!(query[idx], txt, "failed at test {}", idx); + } else { + panic!( "Should be Unknown type, failed at test {}", idx); } - assert_eq!(expected[idx], result, "failed at test {}", idx); + assert_eq!(query[idx], result, "failed at test {}", idx); } } } diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 895b2aba6..4a9f30271 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -75,9 +75,8 @@ fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec let mut result: Vec = vec![]; if let Some(Frame::Cassandra(cassandra)) = message.frame() { if let CassandraOperation::Query { query, .. } = &cassandra.operation { - for cql_statement in &query.statements { - let statement = &cql_statement.statement; - if let CassandraStatement::Select(select) = &statement { + for statement in &query.statements { + if let CassandraStatement::Select(select) = statement { if peer_table == &select.table_name { for select_element in &select.columns { match select_element { diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index bec7f12cc..916a8fb86 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -327,8 +327,7 @@ impl Transform for Protect { .. })) = message.frame() { - for cql_statement in &mut query.statements { - let statement = &mut cql_statement.statement; + for statement in &mut query.statements { data_changed |= self.encrypt_columns(statement).await.unwrap(); if data_changed { debug!("statement changed to {}", statement); @@ -359,8 +358,7 @@ impl Transform for Protect { .. })) = response.frame() { - for cql_statement in &mut query.statements { - let statement = &mut cql_statement.statement; + for statement in &mut query.statements { if let Some(table_name) = CQLStatement::get_table_name(statement) { if let Some(columns) = self.get_protected_columns(table_name) { if let CassandraStatement::Select(select) = &statement { @@ -470,7 +468,7 @@ mod test { // test insert change is reversed on select let stmt_txt = "insert into test_table (col1, col2) VALUES ('Hello World', 'i am clean')"; let mut cql = CQL::parse_from_string(stmt_txt); - let statement = &mut cql.statements[0].statement; + let statement = &mut cql.statements[0]; let data_changed = protect.encrypt_columns(statement).await.unwrap(); assert!(data_changed); @@ -490,10 +488,10 @@ mod test { let stmt_txt = "select col1 from test_table where col2='i am clean'"; let cql = CQL::parse_from_string(stmt_txt); - let statement = &cql.statements[0].statement; + let statement = &cql.statements[0]; if let CassandraStatement::Select(select) = statement { - let result = protect.process_select(select, &cols, &mut rows).await; + let result = protect.process_select(&select, &cols, &mut rows).await; assert!(result.unwrap()); assert_eq!(&msg_value, &rows[0][0]); } diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index 7c92d5bee..da359bc68 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -124,10 +124,9 @@ impl SimpleRedisCache { for cass_request in cassandra_messages { match &mut cass_request.frame() { Some(Frame::Cassandra(frame)) => { - for cql_statement in frame.operation.get_cql_statements() { - let mut state = is_cacheable(cql_statement); + for statement in frame.operation.queries() { + let mut state = is_cacheable(statement); if let CacheableState::Read(table_name) = &mut state { - let statement = &cql_statement.statement; debug!("build_cache_query processing cacheable state"); if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) @@ -166,7 +165,7 @@ impl SimpleRedisCache { } else { state = CacheableState::Skip(format!( "{} is not a readable query", - cql_statement + statement )); } @@ -331,11 +330,11 @@ impl SimpleRedisCache { /// clear the cache for the single row specified by the redis_key fn clear_row_cache( &mut self, - cql_statement: &CQLStatement, + statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, ) -> Option { // TODO is it possible to return the future and process in parallel? - let statement = &cql_statement.statement; + if let Ok((redis_key, _hash_key)) = build_redis_key_from_cql3(statement, table_cache_schema) { let commands_buffer: Vec = vec![ @@ -373,7 +372,7 @@ impl SimpleRedisCache { let result_messages = &mut message_wrapper.call_next_transform().await?; if orig_cql.is_some() { let mut cache_messages: Vec = vec![]; - for (response, cql_statement) in result_messages + for (response, statement) in result_messages .iter_mut() .zip(orig_cql.unwrap().statements.iter()) { @@ -382,12 +381,12 @@ impl SimpleRedisCache { .. })) = response.frame() { - match is_cacheable(cql_statement) { + match is_cacheable(statement) { CacheableState::Update(table_name) | CacheableState::Delete(table_name) => { if let Some(table_cache_schema) = self.caching_schema.get(&table_name) { let table_schema = table_cache_schema.clone(); if let Some(fut_message) = - self.clear_row_cache(cql_statement, &table_schema) + self.clear_row_cache(statement, &table_schema) { cache_messages.push(fut_message); } @@ -400,7 +399,6 @@ impl SimpleRedisCache { self.clear_table_cache(); } CacheableState::Read(table_name) => { - let statement = &cql_statement.statement; if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) { @@ -453,12 +451,12 @@ impl SimpleRedisCache { /// * must specify table name /// * must not contain a parsing error /// * -fn is_cacheable(cql_statement: &CQLStatement) -> CacheableState { - // check issues common to all cql_statements - if let Some(table_name) = CQLStatement::get_table_name(&cql_statement.statement) { - let has_params = CQLStatement::has_params(&cql_statement.statement); +fn is_cacheable(statement: &CassandraStatement) -> CacheableState { + // check issues common to all CassandraStatements + if let Some(table_name) = CQLStatement::get_table_name(statement) { + let has_params = CQLStatement::has_params(statement); - match &cql_statement.statement { + match statement { CassandraStatement::Select(select) => { if has_params { CacheableState::Delete(table_name.into()) @@ -714,9 +712,9 @@ impl Transform for SimpleRedisCache { .. })) = m.frame() { - for cql_statement in &query.statements { - debug!("cache transform processing {}", cql_statement); - match cql_statement.get_query_type() { + for statement in &query.statements { + debug!("cache transform processing {}", statement); + match CQLStatement::get_query_type(statement) { QueryType::Read => {} QueryType::Write => read_cache = false, QueryType::ReadWrite => read_cache = false, @@ -788,7 +786,7 @@ mod test { fn build_query(query_string: &str) -> CassandraStatement { let cql = CQL::parse_from_string(query_string); - cql.statements[0].statement.clone() + cql.statements[0].clone() } #[test] From c89bb1691834848609409e6213a0453faa950c4f Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Thu, 12 May 2022 13:13:47 +0100 Subject: [PATCH 57/60] changes as per review --- Cargo.lock | 3 +- shotover-proxy/src/frame/cassandra.rs | 394 +++++++++++++++---- shotover-proxy/src/message/mod.rs | 4 +- shotover-proxy/src/transforms/protect/mod.rs | 2 +- 4 files changed, 313 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 526cf2405..91879f0b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,7 +447,7 @@ dependencies = [ [[package]] name = "cql3-parser" version = "0.1.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=verify_error#e87dcc207247e735b6942c20d52cbc1d27bcb28a" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=verify_error#5396f01195161f23bb272a4625a6946a1e1c0cfc" dependencies = [ "bigdecimal", "bytes", @@ -455,6 +455,7 @@ dependencies = [ "itertools", "num", "regex", + "serde", "tree-sitter", "tree-sitter-cql", "uuid 1.0.0", diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 20344dff6..d68d8cf22 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -359,9 +359,7 @@ impl CassandraOperation { /// TODO: This will return a custom iterator type when BATCH support is added pub fn queries(&mut self) -> Vec<&mut CassandraStatement> { if let CassandraOperation::Query { query: cql, .. } = self { - cql.statements - .iter_mut() - .collect() + cql.statements.iter_mut().collect() } else { Vec::<&mut CassandraStatement>::new() } @@ -492,11 +490,10 @@ impl CassandraOperation { } #[derive(PartialEq, Debug, Clone)] -pub struct CQLStatement { -} +pub struct CQLStatement {} impl CQLStatement { - pub fn is_begin_batch( statement : &CassandraStatement ) -> bool { + pub fn is_begin_batch(statement: &CassandraStatement) -> bool { match statement { CassandraStatement::Delete(delete) => delete.begin_batch.is_some(), CassandraStatement::Insert(insert) => insert.begin_batch.is_some(), @@ -506,7 +503,7 @@ impl CQLStatement { } /// returns the query type for the current statement. - pub fn get_query_type(statement : &CassandraStatement) -> QueryType { + pub fn get_query_type(statement: &CassandraStatement) -> QueryType { /* the query types in descending order, are: ReadWrite Write @@ -661,10 +658,7 @@ pub struct CQL { impl CQL { fn to_query_string(&self) -> String { - self.statements - .iter() - .map(|c| c.to_string()) - .join("; ") + self.statements.iter().map(|c| c.to_string()).join("; ") } /// the CassandraAST handles multiple queries in a string separated by semi-colons: `;` however @@ -674,10 +668,11 @@ impl CQL { let ast = CassandraAST::new(cql_query_str); let statements = if ast.has_error() { - vec![ CassandraStatement::Unknown( cql_query_str.to_string() )] + vec![CassandraStatement::Unknown(cql_query_str.to_string())] } else { - ast.statements.into_iter() - .map( |(_err, statement, _start, _end)| statement ) + ast.statements + .into_iter() + .map(|parsed_statement| parsed_statement.statement) .collect() }; CQL { statements } @@ -685,83 +680,72 @@ impl CQL { } pub trait ToCassandraType { - fn from_string_value(value: &str) -> Option; - fn as_cassandra_type(&self) -> Option; + fn as_cassandra_type(&self) -> CassandraType; } impl ToCassandraType for Operand { - fn from_string_value(value: &str) -> Option { - // check for string types - if value.starts_with('\'') || value.starts_with("$$") { - let mut chars = value.chars(); - chars.next(); - chars.next_back(); - if value.starts_with('$') { - chars.next(); - chars.next_back(); + fn as_cassandra_type(&self) -> CassandraType { + // function to convert string to CassandraType + let from_string_value = |value: &str| { + // check for string types + if value.starts_with('\'') || value.starts_with("$$") { + /* to convert to a VarChar type we have to strip the delimiters off the front and back + of the string. Soe remove one char (front and back) in the case of `'` and two in the case of `$$` + */ + CassandraType::Varchar(Operand::unescape(value)) + } else if value.starts_with("0X") || value.starts_with("0x") { + hex::decode(&value[2..]) + .map(|x| CassandraType::Blob(Blob::from(x))) + .unwrap_or(CassandraType::Null) + } else if let Ok(n) = i64::from_str(value) { + CassandraType::Bigint(n) + } else if let Ok(n) = f64::from_str(value) { + CassandraType::Double(n) + } else if let Ok(uuid) = Uuid::parse_str(value) { + CassandraType::Uuid(uuid) + } else if let Ok(ipaddr) = IpAddr::from_str(value) { + CassandraType::Inet(ipaddr) + } else { + CassandraType::Null } - Some(CassandraType::Varchar(chars.as_str().to_string())) - } else if value.starts_with("0X") || value.starts_with("0x") { - let mut chars = value.chars(); - chars.next(); - chars.next(); - let bytes = hex::decode(chars.as_str()).unwrap(); - Some(CassandraType::Blob(Blob::from(bytes))) - } else if let Ok(n) = i64::from_str(value) { - Some(CassandraType::Bigint(n)) - } else if let Ok(n) = f64::from_str(value) { - Some(CassandraType::Double(n)) - } else if let Ok(uuid) = Uuid::parse_str(value) { - Some(CassandraType::Uuid(uuid)) - } else if let Ok(ipaddr) = IpAddr::from_str(value) { - Some(CassandraType::Inet(ipaddr)) - } else { - None - } - } - - fn as_cassandra_type(&self) -> Option { + }; match self { - Operand::Const(value) => Operand::from_string_value(value), - Operand::Map(values) => Some(CassandraType::Map( - values + Operand::Const(value) => from_string_value(value), + Operand::Map(values) => { + let mapping = values .iter() - .map(|(key, value)| { - ( - Operand::from_string_value(key).unwrap(), - Operand::from_string_value(value).unwrap(), - ) - }) - .collect(), - )), - Operand::Set(values) => Some(CassandraType::Set( + .map(|(key, value)| (from_string_value(key), from_string_value(value))) + .collect(); + CassandraType::Map(mapping) + } + Operand::Set(values) => CassandraType::Set( values .iter() - .filter_map(|value| Operand::from_string_value(value)) + .map(|value| from_string_value(value)) .collect(), - )), - Operand::List(values) => Some(CassandraType::List( + ), + Operand::List(values) => CassandraType::List( values .iter() - .filter_map(|value| Operand::from_string_value(value)) + .map(|value| from_string_value(value)) .collect(), - )), - Operand::Tuple(values) => Some(CassandraType::Tuple( + ), + Operand::Tuple(values) => CassandraType::Tuple( values .iter() - .filter_map(|value| value.as_cassandra_type()) + .map(|value| value.as_cassandra_type()) .collect(), - )), - Operand::Column(value) => Some(CassandraType::Ascii(value.to_string())), - Operand::Func(value) => Some(CassandraType::Ascii(value.to_string())), - Operand::Null => Some(CassandraType::Null), - Operand::Param(_) => None, - Operand::Collection(values) => Some(CassandraType::List( + ), + Operand::Column(value) => CassandraType::Ascii(value.to_string()), + Operand::Func(value) => CassandraType::Ascii(value.to_string()), + Operand::Null => CassandraType::Null, + Operand::Param(_) => CassandraType::Null, + Operand::Collection(values) => CassandraType::List( values .iter() - .filter_map(|value| value.as_cassandra_type()) + .map(|value| value.as_cassandra_type()) .collect(), - )), + ), } } } @@ -874,8 +858,15 @@ pub struct CassandraBatch { #[cfg(test)] mod test { + use crate::frame::cassandra::ToCassandraType; use crate::frame::CQL; + use cassandra_protocol::types::cassandra_type::CassandraType; + use cassandra_protocol::types::prelude::Blob; use cql3_parser::cassandra_statement::CassandraStatement; + use cql3_parser::common::Operand; + use std::net::IpAddr; + use std::str::FromStr; + use uuid::Uuid; #[test] fn cql_round_trip_test() { @@ -903,33 +894,266 @@ mod test { if let CassandraStatement::Insert(_x) = stmt { // do nothing } else { - panic!( - "{:?} should have been CassandraStatement::Insert", - stmt - ); + panic!("{:?} should have been CassandraStatement::Insert", stmt); } } } #[test] fn cql_bad_statement_test() { - let query = [ + let queries = [ "INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');", r#"BEGIN BATCH INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar'); EXECUTE BATCH"#, r#"INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) VALUES (3, 13, 'baz'); INSERT INTO test_cache_keyspace_batch_insert.test_table (id, x, name) (2, 12, 'bar');"#]; - for idx in 0..query.len() { - let cql = CQL::parse_from_string(query[idx]); + for query in queries { + let cql = CQL::parse_from_string(query); let result = cql.to_query_string(); - assert_eq!( 1, cql.statements.len() ); - if let CassandraStatement::Unknown( txt ) = &cql.statements[0] { - assert_eq!(query[idx], txt, "failed at test {}", idx); + assert_eq!(1, cql.statements.len()); + if let CassandraStatement::Unknown(txt) = &cql.statements[0] { + assert_eq!(query, txt, "failed at test {}", query); } else { - panic!( "Should be Unknown type, failed at test {}", idx); + panic!("Should be Unknown type, failed at test {}", query); } - assert_eq!(query[idx], result, "failed at test {}", idx); + assert_eq!(query, result); + } + } + + #[test] + pub fn test_to_cassandra_type_for_const_operand() { + assert_eq!( + CassandraType::Bigint(55), + Operand::Const("55".to_string()).as_cassandra_type() + ); + assert_eq!( + CassandraType::Double(5.5), + Operand::Const("5.5".to_string()).as_cassandra_type() + ); + let uuid = Uuid::parse_str("123e4567-e89b-12d3-a456-426655440000").unwrap(); + assert_eq!( + CassandraType::Uuid(uuid), + Operand::Const("123e4567-e89b-12d3-a456-426655440000".to_string()).as_cassandra_type() + ); + let ipaddr = IpAddr::from_str("192.168.0.1").unwrap(); + assert_eq!( + CassandraType::Inet(ipaddr), + Operand::Const("192.168.0.1".to_string()).as_cassandra_type() + ); + let ipaddr = IpAddr::from_str("2001:0db8:85a3:0000:0000:8a2e:0370:7334").unwrap(); + assert_eq!( + CassandraType::Inet(ipaddr), + Operand::Const("2001:0db8:85a3:0000:0000:8a2e:0370:7334".to_string()) + .as_cassandra_type() + ); + assert_eq!( + CassandraType::Blob(Blob::from(vec![255_u8, 234_u8, 1_u8, 13_u8])), + Operand::Const("0xFFEA010D".to_string()).as_cassandra_type() + ); + let tests = [ + ( + "'Women''s Tour of New Zealand'", + "Women's Tour of New Zealand", + ), + ( + "$$Women's Tour of New Zealand$$", + "Women's Tour of New Zealand", + ), + ( + "$$Women''s Tour of New Zealand$$", + "Women''s Tour of New Zealand", + ), + ]; + for (txt, expected) in tests { + assert_eq!( + CassandraType::Varchar(expected.to_string()), + Operand::Const(txt.to_string()).as_cassandra_type() + ); } + assert_eq!( + CassandraType::Null, + Operand::Const("not a valid const".to_string()).as_cassandra_type() + ); + assert_eq!( + CassandraType::Null, + Operand::Const("0xnot a hex".to_string()).as_cassandra_type() + ); + } + + #[test] + pub fn test_to_cassandra_type_for_string_collection_operands() { + let args = vec![ + "55".to_string(), + "5.5".to_string(), + "123e4567-e89b-12d3-a456-426655440000".to_string(), + "192.168.0.1".to_string(), + "2001:0db8:85a3:0000:0000:8a2e:0370:7334".to_string(), + "0xFFEA010D".to_string(), + "'Women''s Tour of New Zealand'".to_string(), + "$$Women's Tour of New Zealand$$".to_string(), + "$$Women''s Tour of New Zealand$$".to_string(), + "invalid text".to_string(), + "0xinvalid hex".to_string(), + ]; + + let expected = vec![ + CassandraType::Bigint(55), + CassandraType::Double(5.5), + CassandraType::Uuid(Uuid::parse_str("123e4567-e89b-12d3-a456-426655440000").unwrap()), + CassandraType::Inet(IpAddr::from_str("192.168.0.1").unwrap()), + CassandraType::Inet( + IpAddr::from_str("2001:0db8:85a3:0000:0000:8a2e:0370:7334").unwrap(), + ), + CassandraType::Blob(Blob::from(vec![255_u8, 234_u8, 1_u8, 13_u8])), + CassandraType::Varchar("Women's Tour of New Zealand".to_string()), + CassandraType::Varchar("Women's Tour of New Zealand".to_string()), + CassandraType::Varchar("Women''s Tour of New Zealand".to_string()), + CassandraType::Null, + CassandraType::Null, + ]; + + assert_eq!( + CassandraType::List(expected.clone()), + Operand::List(args.clone()).as_cassandra_type() + ); + assert_eq!( + CassandraType::Set(expected), + Operand::Set(args).as_cassandra_type() + ); + } + + #[test] + pub fn test_to_cassandra_type_for_map_operand() { + let args = vec![ + ("1".to_string(), "55".to_string()), + ("2".to_string(), "5.5".to_string()), + ( + "3".to_string(), + "123e4567-e89b-12d3-a456-426655440000".to_string(), + ), + ("4".to_string(), "192.168.0.1".to_string()), + ( + "5".to_string(), + "2001:0db8:85a3:0000:0000:8a2e:0370:7334".to_string(), + ), + ("6".to_string(), "0xFFEA010D".to_string()), + ( + "7".to_string(), + "'Women''s Tour of New Zealand'".to_string(), + ), + ( + "8".to_string(), + "$$Women's Tour of New Zealand$$".to_string(), + ), + ( + "9".to_string(), + "$$Women''s Tour of New Zealand$$".to_string(), + ), + ("'A'".to_string(), "invalid text".to_string()), + ("'B'".to_string(), "0xinvalid hex".to_string()), + ]; + let expected = vec![ + (CassandraType::Bigint(1), CassandraType::Bigint(55)), + (CassandraType::Bigint(2), CassandraType::Double(5.5)), + ( + CassandraType::Bigint(3), + CassandraType::Uuid( + Uuid::parse_str("123e4567-e89b-12d3-a456-426655440000").unwrap(), + ), + ), + ( + CassandraType::Bigint(4), + CassandraType::Inet(IpAddr::from_str("192.168.0.1").unwrap()), + ), + ( + CassandraType::Bigint(5), + CassandraType::Inet( + IpAddr::from_str("2001:0db8:85a3:0000:0000:8a2e:0370:7334").unwrap(), + ), + ), + ( + CassandraType::Bigint(6), + CassandraType::Blob(Blob::from(vec![255_u8, 234_u8, 1_u8, 13_u8])), + ), + ( + CassandraType::Bigint(7), + CassandraType::Varchar("Women's Tour of New Zealand".to_string()), + ), + ( + CassandraType::Bigint(8), + CassandraType::Varchar("Women's Tour of New Zealand".to_string()), + ), + ( + CassandraType::Bigint(9), + CassandraType::Varchar("Women''s Tour of New Zealand".to_string()), + ), + (CassandraType::Varchar("A".to_string()), CassandraType::Null), + (CassandraType::Varchar("B".to_string()), CassandraType::Null), + ]; + + assert_eq!( + CassandraType::Map(expected), + Operand::Map(args).as_cassandra_type() + ) + } + + #[test] + pub fn test_to_cassandra_type_for_collection_operands() { + let args = vec![ + Operand::Const("55".to_string()), + Operand::Const("5.5".to_string()), + Operand::Const("123e4567-e89b-12d3-a456-426655440000".to_string()), + Operand::Const("192.168.0.1".to_string()), + Operand::Const("2001:0db8:85a3:0000:0000:8a2e:0370:7334".to_string()), + Operand::Const("0xFFEA010D".to_string()), + Operand::Const("'Women''s Tour of New Zealand'".to_string()), + Operand::Const("$$Women's Tour of New Zealand$$".to_string()), + Operand::Const("$$Women''s Tour of New Zealand$$".to_string()), + Operand::Const("invalid text".to_string()), + Operand::Const("0xinvalid hex".to_string()), + ]; + + let expected = vec![ + CassandraType::Bigint(55), + CassandraType::Double(5.5), + CassandraType::Uuid(Uuid::parse_str("123e4567-e89b-12d3-a456-426655440000").unwrap()), + CassandraType::Inet(IpAddr::from_str("192.168.0.1").unwrap()), + CassandraType::Inet( + IpAddr::from_str("2001:0db8:85a3:0000:0000:8a2e:0370:7334").unwrap(), + ), + CassandraType::Blob(Blob::from(vec![255_u8, 234_u8, 1_u8, 13_u8])), + CassandraType::Varchar("Women's Tour of New Zealand".to_string()), + CassandraType::Varchar("Women's Tour of New Zealand".to_string()), + CassandraType::Varchar("Women''s Tour of New Zealand".to_string()), + CassandraType::Null, + CassandraType::Null, + ]; + + assert_eq!( + CassandraType::Tuple(expected.clone()), + Operand::Tuple(args.clone()).as_cassandra_type() + ); + assert_eq!( + CassandraType::List(expected), + Operand::Collection(args).as_cassandra_type() + ); + } + + #[test] + pub fn test_to_cassandra_type_for_misc_operands() { + assert_eq!( + CassandraType::Ascii("Hello".to_string()), + Operand::Column("Hello".to_string()).as_cassandra_type() + ); + assert_eq!( + CassandraType::Ascii("Hello".to_string()), + Operand::Func("Hello".to_string()).as_cassandra_type() + ); + assert_eq!(CassandraType::Null, Operand::Null.as_cassandra_type()); + assert_eq!( + CassandraType::Null, + Operand::Param("Hello".to_string()).as_cassandra_type() + ); } } diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 383ee4a92..8e3842297 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -549,9 +549,7 @@ impl From<&MessageValue> for Operand { impl From<&Operand> for MessageValue { fn from(operand: &Operand) -> Self { - operand - .as_cassandra_type() - .map_or(MessageValue::None, MessageValue::create_element) + MessageValue::create_element(operand.as_cassandra_type()) } } diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 916a8fb86..850d6bcea 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -491,7 +491,7 @@ mod test { let statement = &cql.statements[0]; if let CassandraStatement::Select(select) = statement { - let result = protect.process_select(&select, &cols, &mut rows).await; + let result = protect.process_select(select, &cols, &mut rows).await; assert!(result.unwrap()); assert_eq!(&msg_value, &rows[0][0]); } From 2beccfa091a5845583f526332e1f1920caf085c3 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 13 May 2022 09:47:27 +0100 Subject: [PATCH 58/60] changes as per review --- Cargo.lock | 2 +- shotover-proxy/src/frame/cassandra.rs | 44 ++-- shotover-proxy/src/transforms/protect/mod.rs | 6 +- shotover-proxy/src/transforms/redis/cache.rs | 244 +++++++++++-------- 4 files changed, 159 insertions(+), 137 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 91879f0b8..b969aeaa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,7 +447,7 @@ dependencies = [ [[package]] name = "cql3-parser" version = "0.1.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=verify_error#5396f01195161f23bb272a4625a6946a1e1c0cfc" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=verify_error#a9063728fac864c83a9d5139406d491753aeb425" dependencies = [ "bigdecimal", "bytes", diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index d68d8cf22..22561f119 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -22,9 +22,7 @@ use cassandra_protocol::types::cassandra_type::CassandraType; use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{FQName, Operand, RelationElement}; -use cql3_parser::insert::InsertValues; -use cql3_parser::update::AssignmentOperator; +use cql3_parser::common::{FQName, Operand}; use itertools::Itertools; use nonzero_ext::nonzero; use sodiumoxide::hex; @@ -274,7 +272,7 @@ impl CassandraFrame { // set to lowest type let mut result = QueryType::PubSubMessage; for cql_statement in &cql.statements { - result = match (CQLStatement::get_query_type(cql_statement), &result) { + result = match (cql_statement::get_query_type(cql_statement), &result) { (QueryType::ReadWrite, _) => QueryType::ReadWrite, (QueryType::Write, QueryType::Read | QueryType::ReadWrite) => { QueryType::ReadWrite @@ -305,7 +303,7 @@ impl CassandraFrame { self.operation .queries() .into_iter() - .filter_map(|stmt| CQLStatement::get_table_name(stmt)) + .filter_map(|stmt| cql_statement::get_table_name(stmt)) .collect() } @@ -489,18 +487,12 @@ impl CassandraOperation { } } -#[derive(PartialEq, Debug, Clone)] -pub struct CQLStatement {} - -impl CQLStatement { - pub fn is_begin_batch(statement: &CassandraStatement) -> bool { - match statement { - CassandraStatement::Delete(delete) => delete.begin_batch.is_some(), - CassandraStatement::Insert(insert) => insert.begin_batch.is_some(), - CassandraStatement::Update(update) => update.begin_batch.is_some(), - _ => false, - } - } +pub mod cql_statement { + use crate::message::QueryType; + use cql3_parser::cassandra_statement::CassandraStatement; + use cql3_parser::common::{FQName, Operand, RelationElement}; + use cql3_parser::insert::InsertValues; + use cql3_parser::update::AssignmentOperator; /// returns the query type for the current statement. pub fn get_query_type(statement: &CassandraStatement) -> QueryType { @@ -575,7 +567,7 @@ impl CQLStatement { match operand { Operand::Tuple(vec) | Operand::Collection(vec) => { for oper in vec { - if CQLStatement::has_params_in_operand(oper) { + if has_params_in_operand(oper) { return true; } } @@ -588,7 +580,7 @@ impl CQLStatement { fn has_params_in_relation_elements(where_clause: &[RelationElement]) -> bool { for relation_idx in where_clause { - if CQLStatement::has_params_in_operand(&relation_idx.value) { + if has_params_in_operand(&relation_idx.value) { return true; } } @@ -599,10 +591,10 @@ impl CQLStatement { pub fn has_params(statement: &CassandraStatement) -> bool { match statement { CassandraStatement::Delete(delete) => { - if CQLStatement::has_params_in_relation_elements(&delete.where_clause) { + if has_params_in_relation_elements(&delete.where_clause) { return true; } - if CQLStatement::has_params_in_relation_elements(&delete.if_clause) { + if has_params_in_relation_elements(&delete.if_clause) { return true; } } @@ -616,7 +608,7 @@ impl CQLStatement { } } CassandraStatement::Select(select) => { - return CQLStatement::has_params_in_relation_elements(&select.where_clause); + return has_params_in_relation_elements(&select.where_clause); } CassandraStatement::Update(update) => { for assignment_element in &update.assignments { @@ -638,10 +630,10 @@ impl CQLStatement { } } } - if CQLStatement::has_params_in_relation_elements(&update.where_clause) { + if has_params_in_relation_elements(&update.where_clause) { return true; } - if CQLStatement::has_params_in_relation_elements(&update.if_clause) { + if has_params_in_relation_elements(&update.if_clause) { return true; } } @@ -686,7 +678,7 @@ pub trait ToCassandraType { impl ToCassandraType for Operand { fn as_cassandra_type(&self) -> CassandraType { // function to convert string to CassandraType - let from_string_value = |value: &str| { + fn from_string_value(value: &str) -> CassandraType { // check for string types if value.starts_with('\'') || value.starts_with("$$") { /* to convert to a VarChar type we have to strip the delimiters off the front and back @@ -708,7 +700,7 @@ impl ToCassandraType for Operand { } else { CassandraType::Null } - }; + } match self { Operand::Const(value) => from_string_value(value), Operand::Map(values) => { diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 850d6bcea..5543283ca 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -1,5 +1,5 @@ use crate::error::ChainResponse; -use crate::frame::cassandra::CQLStatement; +use crate::frame::cassandra::cql_statement; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::MessageValue; use crate::transforms::protect::key_management::{KeyManager, KeyManagerConfig}; @@ -86,7 +86,7 @@ impl Protect { /// * `key_id` the key within the manager to use. async fn encrypt_columns(&self, statement: &mut CassandraStatement) -> Result { let mut data_changed = false; - if let Some(table_name) = CQLStatement::get_table_name(statement) { + if let Some(table_name) = cql_statement::get_table_name(statement) { if let Some(columns) = self.get_protected_columns(table_name) { match statement { CassandraStatement::Insert(insert) => { @@ -359,7 +359,7 @@ impl Transform for Protect { })) = response.frame() { for statement in &mut query.statements { - if let Some(table_name) = CQLStatement::get_table_name(statement) { + if let Some(table_name) = cql_statement::get_table_name(statement) { if let Some(columns) = self.get_protected_columns(table_name) { if let CassandraStatement::Select(select) = &statement { invalidate_cache |= diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index da359bc68..c5830e012 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -1,6 +1,6 @@ use crate::config::topology::TopicHolder; use crate::error::ChainResponse; -use crate::frame::cassandra::CQLStatement; +use crate::frame::cassandra::cql_statement; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, RedisFrame, CQL}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::TransformChain; @@ -9,7 +9,7 @@ use crate::transforms::{ }; use anyhow::Result; use async_trait::async_trait; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use cassandra_protocol::frame::Serialize; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; @@ -38,40 +38,34 @@ metadata - serialized form of the metadata from cassandra. */ enum CacheableState { - /// string is the table name - Read(String), - /// string is the table name - Update(String), - /// string is the table name - Delete(String), - /// string is the table being dropped - Drop(String), - /// string is the reason for the skip - Skip(String), - /// string is the reason for the error - Err(String), + Read { table_name: String }, + Update { table_name: String }, + Delete { table_name: String }, + Drop { table_name: String }, + Skip { reason: String }, + Err { reason: String }, } impl Display for CacheableState { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - CacheableState::Read(name) => { - write!(f, "Reading {}", name) + CacheableState::Read { table_name } => { + write!(f, "Reading {}", table_name) } - CacheableState::Update(name) => { - write!(f, "Updating {}", name) + CacheableState::Update { table_name } => { + write!(f, "Updating {}", table_name) } - CacheableState::Delete(name) => { - write!(f, "Deleting {}", name) + CacheableState::Delete { table_name } => { + write!(f, "Deleting {}", table_name) } - CacheableState::Drop(name) => { - write!(f, "Dropping {}", name) + CacheableState::Drop { table_name } => { + write!(f, "Dropping {}", table_name) } - CacheableState::Skip(txt) => { - write!(f, "Skipping due to: {}", txt) + CacheableState::Skip { reason } => { + write!(f, "Skipping due to: {}", reason) } - CacheableState::Err(txt) => { - write!(f, "Error due to: {}", txt) + CacheableState::Err { reason } => { + write!(f, "Error due to: {}", reason) } } } @@ -126,7 +120,7 @@ impl SimpleRedisCache { Some(Frame::Cassandra(frame)) => { for statement in frame.operation.queries() { let mut state = is_cacheable(statement); - if let CacheableState::Read(table_name) = &mut state { + if let CacheableState::Read { table_name } = &mut state { debug!("build_cache_query processing cacheable state"); if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) @@ -157,20 +151,18 @@ impl SimpleRedisCache { } } } else { - state = CacheableState::Skip(format!( - "Table {} not in caching list", - table_name - )); + state = CacheableState::Skip { + reason: format!("Table {} not in caching list", table_name), + }; } } else { - state = CacheableState::Skip(format!( - "{} is not a readable query", - statement - )); + state = CacheableState::Skip { + reason: format!("{} is not a readable query", statement), + }; } match state { - CacheableState::Err(_) | CacheableState::Skip(_) => { + CacheableState::Err { .. } | CacheableState::Skip { .. } => { debug!("build_cache_query: {}", state); return Err(state); } @@ -179,9 +171,9 @@ impl SimpleRedisCache { } } _ => { - return Err(CacheableState::Err(format!( - "cannot fetch {cass_request:?} from cache" - ))) + return Err(CacheableState::Err { + reason: format!("cannot fetch {cass_request:?} from cache"), + }) } } } @@ -206,55 +198,60 @@ impl SimpleRedisCache { for cass_request in cassandra_messages.iter_mut() { // the responses for this request let cassandra_result: Result = - if let Some(Frame::Cassandra(frame)) = &mut cass_request.frame() { + if let Some(Frame::Cassandra(frame)) = cass_request.frame() { let queries = frame.operation.queries(); if queries.len() != 1 { - Err(CacheableState::Err( - "Cacheable Cassandra query must be only one statement".into(), - )) + Err(CacheableState::Err { + reason: "Cacheable Cassandra query must be only one statement".into(), + }) } else if let Some(mut redis_response) = messages_redis_response_iter.next() { match redis_response.frame() { Some(Frame::Redis(redis_frame)) => { match redis_frame { - RedisFrame::SimpleString(_) => Err(CacheableState::Err( - "Redis returned a simple string".into(), - )), + RedisFrame::SimpleString(_) => Err(CacheableState::Err { + reason: "Redis returned a simple string".into(), + }), RedisFrame::Error(e) => { - return Err(CacheableState::Err(e.to_string())) + return Err(CacheableState::Err { + reason: e.to_string(), + }) } - RedisFrame::Integer(_) => Err(CacheableState::Err( - "Redis returned an int value".into(), - )), + RedisFrame::Integer(_) => Err(CacheableState::Err { + reason: "Redis returned an int value".into(), + }), RedisFrame::BulkString(redis_bytes) => { // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) - let x = redis_bytes.iter().copied().collect_vec(); - let mut cursor = Cursor::new(x.as_slice()); + let mut cursor = Cursor::new(redis_bytes.as_ref()); let answer = CassandraResult::from_cursor(&mut cursor, Version::V4); if let Ok(result) = answer { Ok(result) } else { - Err(CacheableState::Err( - answer.err().unwrap().to_string(), - )) + Err(CacheableState::Err { + reason: answer.unwrap_err().to_string(), + }) } } - RedisFrame::Array(_) => Err(CacheableState::Err( - "Redis returned an array value".into(), - )), + RedisFrame::Array(_) => Err(CacheableState::Err { + reason: "Redis returned an array value".into(), + }), RedisFrame::Null => { self.missed_requests.increment(1); - Err(CacheableState::Skip("No cache results".into())) + Err(CacheableState::Skip { + reason: "No cache results".into(), + }) } } } - _ => Err(CacheableState::Err( - "No Redis frame in Redis response".into(), - )), + _ => Err(CacheableState::Err { + reason: "No Redis frame in Redis response".into(), + }), } } else { - Err(CacheableState::Err("Redis response was None".into())) + Err(CacheableState::Err { + reason: "Redis response was None".into(), + }) } } else { Ok(CassandraResult::Void) @@ -313,7 +310,9 @@ impl SimpleRedisCache { "clientdetailstodo".to_string(), ) .await - .map_err(|e| CacheableState::Err(format!("Redis error: {}", e)))?; + .map_err(|e| CacheableState::Err { + reason: format!("Redis error: {}", e), + })?; debug!("read_from_cache received OK from cache_chain.process_request"); self.unwrap_cache_response(messages_redis_response, cassandra_messages) @@ -370,11 +369,9 @@ impl SimpleRedisCache { }) .next(); let result_messages = &mut message_wrapper.call_next_transform().await?; - if orig_cql.is_some() { + if let Some(orig_cql) = orig_cql { let mut cache_messages: Vec = vec![]; - for (response, statement) in result_messages - .iter_mut() - .zip(orig_cql.unwrap().statements.iter()) + for (response, statement) in result_messages.iter_mut().zip(orig_cql.statements.iter()) { if let Some(Frame::Cassandra(CassandraFrame { operation: CassandraOperation::Result(result), @@ -382,7 +379,8 @@ impl SimpleRedisCache { })) = response.frame() { match is_cacheable(statement) { - CacheableState::Update(table_name) | CacheableState::Delete(table_name) => { + CacheableState::Update { table_name } + | CacheableState::Delete { table_name } => { if let Some(table_cache_schema) = self.caching_schema.get(&table_name) { let table_schema = table_cache_schema.clone(); if let Some(fut_message) = @@ -394,11 +392,11 @@ impl SimpleRedisCache { debug!("table {} is not being cached", table_name); } } - CacheableState::Drop(table_name) => { + CacheableState::Drop { table_name } => { info!("table {} dropped", table_name); self.clear_table_cache(); } - CacheableState::Read(table_name) => { + CacheableState::Read { table_name } => { if let Some(table_cache_schema) = self.caching_schema.get(table_name.as_str()) { @@ -422,7 +420,7 @@ impl SimpleRedisCache { } } } - CacheableState::Skip(_reason) | CacheableState::Err(_reason) => { + CacheableState::Skip { .. } | CacheableState::Err { .. } => { // do nothing } } @@ -453,32 +451,48 @@ impl SimpleRedisCache { /// * fn is_cacheable(statement: &CassandraStatement) -> CacheableState { // check issues common to all CassandraStatements - if let Some(table_name) = CQLStatement::get_table_name(statement) { - let has_params = CQLStatement::has_params(statement); + if let Some(table_name) = cql_statement::get_table_name(statement) { + let has_params = cql_statement::has_params(statement); match statement { CassandraStatement::Select(select) => { if has_params { - CacheableState::Delete(table_name.into()) + CacheableState::Delete { + table_name: table_name.into(), + } } else if select.filtering { - CacheableState::Skip("Can not cache with ALLOW FILTERING".into()) + CacheableState::Skip { + reason: "Can not cache with ALLOW FILTERING".into(), + } } else if select.where_clause.is_empty() { - CacheableState::Skip("Can not cache if where clause is empty".into()) + CacheableState::Skip { + reason: "Can not cache if where clause is empty".into(), + } } else { - CacheableState::Read(table_name.into()) + CacheableState::Read { + table_name: table_name.to_string(), + } } } CassandraStatement::Insert(insert) => { if has_params || insert.if_not_exists { - CacheableState::Delete(table_name.into()) + CacheableState::Delete { + table_name: table_name.into(), + } } else { - CacheableState::Update(table_name.into()) + CacheableState::Update { + table_name: table_name.into(), + } } } - CassandraStatement::DropTable(_) => CacheableState::Drop(table_name.into()), + CassandraStatement::DropTable(_) => CacheableState::Drop { + table_name: table_name.into(), + }, CassandraStatement::Update(update) => { if has_params || update.if_exists { - CacheableState::Delete(table_name.into()) + CacheableState::Delete { + table_name: table_name.into(), + } } else { for assignment_element in &update.assignments { if assignment_element.operator.is_some() { @@ -486,24 +500,34 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { "Clearing {} cache: {} has calculations in values", update.table_name, assignment_element.name ); - return CacheableState::Delete(table_name.into()); + return CacheableState::Delete { + table_name: table_name.into(), + }; } if assignment_element.name.idx.is_some() { debug!( "Clearing {} cache: {} is an indexed columns", update.table_name, assignment_element.name ); - return CacheableState::Delete(table_name.into()); + return CacheableState::Delete { + table_name: table_name.into(), + }; } } - CacheableState::Update(table_name.into()) + CacheableState::Update { + table_name: table_name.into(), + } } } - _ => CacheableState::Skip("Statement is not a cacheable type".into()), + _ => CacheableState::Skip { + reason: "Statement is not a cacheable type".into(), + }, } } else { - CacheableState::Skip("No table name specified".into()) + CacheableState::Skip { + reason: "No table name specified".into(), + } } } @@ -521,17 +545,21 @@ fn build_query_redis_key_from_value_map( debug!("processing partition key segment: {}", column_name); match query_values.get(column_name.as_str()) { None => { - return Err(CacheableState::Skip(format!( - "Partition key not complete. missing segment {}", - column_name - ))); + return Err(CacheableState::Skip { + reason: format!( + "Partition key not complete. missing segment {}", + column_name + ), + }); } Some(relation_elements) => { if relation_elements.len() > 1 { - return Err(CacheableState::Skip(format!( - "partition key segment {} has more than one relationship", - column_name - ))); + return Err(CacheableState::Skip { + reason: format!( + "partition key segment {} has more than one relationship", + column_name + ), + }); } debug!( "extending key with segment {} value {}", @@ -553,15 +581,17 @@ fn build_query_redis_key_from_value_map( Some(relation_elements) => { if skipping { // we skipped an earlier column so this is an error. - return Err(CacheableState::Err( - "Columns in the middle of the range key were skipped".into(), - )); + return Err(CacheableState::Err { + reason: "Columns in the middle of the range key were skipped".into(), + }); } if relation_elements.len() > 1 { - return Err(CacheableState::Skip(format!( - "partition key segment {} has more than one relationship", - column_name - ))); + return Err(CacheableState::Skip { + reason: format!( + "partition key segment {} has more than one relationship", + column_name + ), + }); } debug!( "extending key with segment {} value {}", @@ -573,7 +603,7 @@ fn build_query_redis_key_from_value_map( } } } - Ok(BytesMut::from(key.as_slice()).freeze()) + Ok(Bytes::from(key)) } /// build the redis key for the query. @@ -624,7 +654,7 @@ fn build_query_redis_hash_from_value_map( .as_str(), ); - Ok(BytesMut::from(str.as_str()).freeze()) + Ok(Bytes::from(str)) } fn populate_value_map_from_where_clause( @@ -714,7 +744,7 @@ impl Transform for SimpleRedisCache { { for statement in &query.statements { debug!("cache transform processing {}", statement); - match CQLStatement::get_query_type(statement) { + match cql_statement::get_query_type(statement) { QueryType::Read => {} QueryType::Write => read_cache = false, QueryType::ReadWrite => read_cache = false, @@ -733,12 +763,12 @@ impl Transform for SimpleRedisCache { match self.read_from_cache(message_wrapper.messages.clone()).await { Ok(cr) => return Ok(cr), Err(inner_state) => match &inner_state { - CacheableState::Skip(reason) => { + CacheableState::Skip { reason } => { info!("Cache skipped: {} ", reason); self.execute_upstream_and_process_result(message_wrapper) .await } - CacheableState::Err(reason) => { + CacheableState::Err { reason } => { error!("Cache failed: {} ", reason); message_wrapper.call_next_transform().await } From 5ea6dc40df3bcd380b77178a82e123b1fc750bc1 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Fri, 20 May 2022 07:44:20 +0100 Subject: [PATCH 59/60] switched Identifier for named items --- Cargo.lock | 9 +- shotover-proxy/Cargo.toml | 2 +- shotover-proxy/src/frame/cassandra.rs | 10 +- shotover-proxy/src/message/mod.rs | 4 +- .../src/transforms/cassandra/peers_rewrite.rs | 40 +++-- shotover-proxy/src/transforms/protect/mod.rs | 21 +-- shotover-proxy/src/transforms/redis/cache.rs | 158 +++++++++--------- 7 files changed, 128 insertions(+), 116 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b969aeaa2..062229f6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -446,8 +446,8 @@ dependencies = [ [[package]] name = "cql3-parser" -version = "0.1.1" -source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=verify_error#a9063728fac864c83a9d5139406d491753aeb425" +version = "0.2.0" +source = "git+https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git?branch=shotover_version#7ea18c7e3021b5630fb576bb40016c62cc995b63" dependencies = [ "bigdecimal", "bytes", @@ -3084,9 +3084,8 @@ dependencies = [ [[package]] name = "tree-sitter-cql" -version = "0.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fdcd73bf0389e82b70592060b7e25195d8d1728d7a0b76b549e92f7f5d124b4" +version = "0.0.2" +source = "git+https://github.com/shotover/tree-sitter-cql.git?branch=main#23e030fe4d6abba2aa15d840825330fc9f1d1d12" dependencies = [ "cc", "regex", diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 3b7afbee9..ecffb8f88 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -43,7 +43,7 @@ thiserror = "1.0" anyhow = "1.0.31" # Parsers -cql3-parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch="verify_error" } +cql3-parser = { git = "https://github.com/Claude-at-Instaclustr/rust_cql3_parser.git", branch="shotover_version" } serde = { version = "1.0.111", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.8.21" diff --git a/shotover-proxy/src/frame/cassandra.rs b/shotover-proxy/src/frame/cassandra.rs index 22561f119..a1a419fb0 100644 --- a/shotover-proxy/src/frame/cassandra.rs +++ b/shotover-proxy/src/frame/cassandra.rs @@ -22,7 +22,7 @@ use cassandra_protocol::types::cassandra_type::CassandraType; use cassandra_protocol::types::{CBytes, CBytesShort, CInt, CLong}; use cql3_parser::cassandra_ast::CassandraAST; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{FQName, Operand}; +use cql3_parser::common::{FQName, Identifier, Operand}; use itertools::Itertools; use nonzero_ext::nonzero; use sodiumoxide::hex; @@ -308,7 +308,7 @@ impl CassandraFrame { } /// returns a list of unique keyspace (namespace) from the table names in the statement(s). - pub fn namespace(&mut self) -> Vec { + pub fn namespace(&mut self) -> Vec { self.get_table_names() .into_iter() .filter_map(|fq_name| fq_name.keyspace.clone()) @@ -855,7 +855,7 @@ mod test { use cassandra_protocol::types::cassandra_type::CassandraType; use cassandra_protocol::types::prelude::Blob; use cql3_parser::cassandra_statement::CassandraStatement; - use cql3_parser::common::Operand; + use cql3_parser::common::{Identifier, Operand}; use std::net::IpAddr; use std::str::FromStr; use uuid::Uuid; @@ -1136,11 +1136,11 @@ mod test { pub fn test_to_cassandra_type_for_misc_operands() { assert_eq!( CassandraType::Ascii("Hello".to_string()), - Operand::Column("Hello".to_string()).as_cassandra_type() + Operand::Column(Identifier::parse("Hello")).as_cassandra_type() ); assert_eq!( CassandraType::Ascii("Hello".to_string()), - Operand::Func("Hello".to_string()).as_cassandra_type() + Operand::Func(Identifier::parse("Hello")).as_cassandra_type() ); assert_eq!(CassandraType::Null, Operand::Null.as_cassandra_type()); assert_eq!( diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 580b131e1..3c3c365c5 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -19,7 +19,7 @@ use cassandra_protocol::{ CBytes, }, }; -use cql3_parser::common::{DataTypeName, Operand}; +use cql3_parser::common::{DataTypeName, Identifier, Operand}; use itertools::Itertools; use nonzero_ext::nonzero; use num::BigInt; @@ -199,7 +199,7 @@ impl Message { } /// Returns None when fails to parse the message - pub fn namespace(&mut self) -> Option> { + pub fn namespace(&mut self) -> Option> { match self.frame()? { Frame::Cassandra(cassandra) => Some(cassandra.namespace()), Frame::Redis(_) => unimplemented!(), diff --git a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs index 645841c29..ed7ee526c 100644 --- a/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover-proxy/src/transforms/cassandra/peers_rewrite.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use async_trait::async_trait; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::FQName; +use cql3_parser::common::{FQName, Identifier}; use cql3_parser::select::SelectElement; use serde::Deserialize; use std::collections::HashMap; @@ -45,7 +45,7 @@ impl Transform for CassandraPeersRewrite { async fn transform<'a>(&'a mut self, mut message_wrapper: Wrapper<'a>) -> ChainResponse { // Find the indices of queries to system.peers & system.peers_v2 // we need to know which columns in which CQL queries in which messages have system peers - let column_names: HashMap> = message_wrapper + let column_names: HashMap> = message_wrapper .messages .iter_mut() .enumerate() @@ -71,8 +71,9 @@ impl Transform for CassandraPeersRewrite { /// determine if the message contains a SELECT from `system.peers_v2` that includes the `native_port` column /// return a list of column names (or their alias) for each `native_port`. -fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec { - let mut result: Vec = vec![]; +fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec { + let mut result: Vec = vec![]; + let native_port = Identifier::parse("native_port"); if let Some(Frame::Cassandra(cassandra)) = message.frame() { if let CassandraOperation::Query { query, .. } = &cassandra.operation { for statement in &query.statements { @@ -81,11 +82,11 @@ fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec for select_element in &select.columns { match select_element { SelectElement::Column(col_name) => { - if col_name.name == "native_port" { - result.push(col_name.alias_or_name().to_string()); + if col_name.name == native_port { + result.push(col_name.alias_or_name().clone()); } } - SelectElement::Star => result.push("native_port".to_string()), + SelectElement::Star => result.push(native_port.clone()), _ => {} } } @@ -99,7 +100,7 @@ fn extract_native_port_column(peer_table: &FQName, message: &mut Message) -> Vec /// Rewrite the `native_port` field in the results from a query to `system.peers_v2` table /// Only Cassandra queries to the `system.peers` table found via the `is_system_peers` function should be passed to this -fn rewrite_port(message: &mut Message, column_names: &[String], new_port: u32) { +fn rewrite_port(message: &mut Message, column_names: &[Identifier], new_port: u32) { if let Some(Frame::Cassandra(frame)) = message.frame() { if let CassandraOperation::Result(CassandraResult::Rows { value, metadata }) = &mut frame.operation @@ -109,7 +110,7 @@ fn rewrite_port(message: &mut Message, column_names: &[String], new_port: u32) { .iter() .enumerate() .filter_map(|(idx, col)| { - if column_names.contains(&col.name) { + if column_names.contains(&Identifier::parse(&col.name)) { Some(idx) } else { None @@ -198,12 +199,15 @@ mod test { #[test] fn test_is_system_peers_v2() { + let native_port = Identifier::parse("native_port"); + let foo = Identifier::parse("foo"); + let peer_table = FQName::new("system", "peers_v2"); let v = extract_native_port_column( &peer_table, &mut create_query_message("SELECT * FROM system.peers_v2;"), ); - assert_eq!(vec!("native_port".to_string()), v); + assert_eq!(vec!(native_port.clone()), v); let v = extract_native_port_column( &peer_table, @@ -215,7 +219,7 @@ mod test { &peer_table, &mut create_query_message("SELECT native_port as foo from system.peers_v2"), ); - assert_eq!(vec!("foo".to_string()), v); + assert_eq!(vec!(foo.clone()), v); let v = extract_native_port_column( &peer_table, @@ -223,7 +227,7 @@ mod test { "SELECT native_port as foo, native_port from system.peers_v2", ), ); - assert_eq!(vec!["foo".to_string(), "native_port".to_string()], v); + assert_eq!(vec![foo, native_port], v); } #[test] @@ -246,7 +250,7 @@ mod test { ], ); - rewrite_port(&mut message, &["native_port".to_string()], 9043); + rewrite_port(&mut message, &[Identifier::parse("native_port")], 9043); let expected = create_response_message( &col_spec, @@ -288,7 +292,10 @@ mod test { rewrite_port( &mut original, - &["native_port".to_string(), "alias_port".to_string()], + &[ + Identifier::parse("native_port"), + Identifier::parse("alias_port"), + ], 9043, ); @@ -358,7 +365,10 @@ mod test { rewrite_port( &mut original, - &["native_port".to_string(), "alias_port".to_string()], + &[ + Identifier::parse("native_port"), + Identifier::parse("alias_port"), + ], 9043, ); diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index 5543283ca..adc0656f2 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -9,7 +9,7 @@ use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{FQName, Operand}; +use cql3_parser::common::{FQName, Identifier, Operand}; use cql3_parser::insert::InsertValues; use cql3_parser::select::{Select, SelectElement}; use serde::{Deserialize, Serialize}; @@ -26,7 +26,8 @@ mod pkcs_11; #[derive(Clone)] pub struct Protect { - keyspace_table_columns: HashMap>>, + /// map of keyspace Identifiers to map of table Identifiers to column Identifiers + keyspace_table_columns: HashMap>>, key_source: KeyManager, // TODO this should be a function to create key_ids based on "something", e.g. primary key // for the moment this is just a string @@ -52,11 +53,11 @@ impl Protect { /// get the list of protected columns for the specified table name. Will return `None` if no columns /// are defined for the table. - fn get_protected_columns(&self, table_name: &FQName) -> Option<&Vec> { + fn get_protected_columns(&self, table_name: &FQName) -> Option<&Vec> { // TODO replace "" with cached keyspace name if let Some(tables) = self .keyspace_table_columns - .get(table_name.extract_keyspace("")) + .get(table_name.extract_keyspace(&Identifier::default())) { tables.get(&table_name.name) } else { @@ -149,7 +150,7 @@ impl Protect { async fn process_select( &self, select: &Select, - columns: &[String], + columns: &[Identifier], rows: &mut Vec>, ) -> Result { let mut modified = false; @@ -197,7 +198,7 @@ pub struct KeyMaterial { #[derive(Deserialize, Debug, Clone)] pub struct ProtectConfig { - pub keyspace_table_columns: HashMap>>, + pub keyspace_table_columns: HashMap>>, pub key_manager: KeyManagerConfig, } @@ -390,7 +391,7 @@ mod test { use crate::transforms::protect::{Protect, Protected}; use bytes::Bytes; use cql3_parser::cassandra_statement::CassandraStatement; - use cql3_parser::common::Operand; + use cql3_parser::common::{Identifier, Operand}; use cql3_parser::insert::InsertValues; use sodiumoxide::crypto::secretbox::Nonce; use std::collections::HashMap; @@ -441,11 +442,11 @@ mod test { kek_id: "".to_string(), }; - let cols = vec!["col1".to_string()]; + let cols = vec![Identifier::parse("col1")]; let mut tables = HashMap::new(); - tables.insert("test_table".to_string(), cols.clone()); + tables.insert(Identifier::parse("test_table"), cols.clone()); let mut keyspace_table_columns = HashMap::new(); - keyspace_table_columns.insert("".to_string(), tables); + keyspace_table_columns.insert(Default::default(), tables); let protect = Protect { keyspace_table_columns, key_source: KeyManager::Local(local_key_mgr), diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index c5830e012..d51dbc21c 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -13,8 +13,8 @@ use bytes::Bytes; use cassandra_protocol::frame::Serialize; use cassandra_protocol::frame::Version; use cql3_parser::cassandra_statement::CassandraStatement; -use cql3_parser::common::{Operand, RelationElement, RelationOperator}; -use cql3_parser::select::{Named, Select, SelectElement}; +use cql3_parser::common::{FQName, Identifier, Operand, RelationElement, RelationOperator}; +use cql3_parser::select::Select; use itertools::Itertools; use metrics::{register_counter, Counter}; use serde::Deserialize; @@ -38,10 +38,10 @@ metadata - serialized form of the metadata from cassandra. */ enum CacheableState { - Read { table_name: String }, - Update { table_name: String }, - Delete { table_name: String }, - Drop { table_name: String }, + Read { table_name: FQName }, + Update { table_name: FQName }, + Delete { table_name: FQName }, + Drop { table_name: FQName }, Skip { reason: String }, Err { reason: String }, } @@ -73,24 +73,48 @@ impl Display for CacheableState { #[derive(Deserialize, Debug, Clone)] pub struct RedisConfig { - pub caching_schema: HashMap, + pub caching_schema: HashMap, pub chain: Vec, } #[derive(Deserialize, Debug, Clone)] -pub struct TableCacheSchema { +pub struct TableCacheSchemaConfig { partition_key: Vec, range_key: Vec, } +#[derive(Deserialize, Debug, Clone)] +pub struct TableCacheSchema { + partition_key: Vec, + range_key: Vec, +} + +impl From<&TableCacheSchemaConfig> for TableCacheSchema { + fn from(cfg: &TableCacheSchemaConfig) -> Self { + TableCacheSchema { + partition_key: cfg + .partition_key + .iter() + .map(|s| Identifier::parse(s)) + .collect(), + range_key: cfg.range_key.iter().map(|s| Identifier::parse(s)).collect(), + } + } +} + impl RedisConfig { pub async fn get_transform(&self, topics: &TopicHolder) -> Result { let missed_requests = register_counter!("cache_miss"); + let mut caching_schema: HashMap = HashMap::new(); + self.caching_schema.iter().for_each(|(k, v)| { + caching_schema.insert(FQName::parse(k), v.into()); + }); + Ok(Transforms::RedisCache(SimpleRedisCache { cache_chain: build_chain_from_config("cache_chain".to_string(), &self.chain, topics) .await?, - caching_schema: self.caching_schema.clone(), + caching_schema, missed_requests, })) } @@ -99,7 +123,7 @@ impl RedisConfig { #[derive(Clone)] pub struct SimpleRedisCache { cache_chain: TransformChain, - caching_schema: HashMap, + caching_schema: HashMap, missed_requests: Counter, } @@ -122,9 +146,7 @@ impl SimpleRedisCache { let mut state = is_cacheable(statement); if let CacheableState::Read { table_name } = &mut state { debug!("build_cache_query processing cacheable state"); - if let Some(table_cache_schema) = - self.caching_schema.get(table_name.as_str()) - { + if let Some(table_cache_schema) = self.caching_schema.get(table_name) { match build_redis_key_from_cql3(statement, table_cache_schema) { Ok((redis_key, hash_key)) => { trace!( @@ -368,7 +390,7 @@ impl SimpleRedisCache { } }) .next(); - let result_messages = &mut message_wrapper.call_next_transform().await?; + let mut result_messages = message_wrapper.call_next_transform().await?; if let Some(orig_cql) = orig_cql { let mut cache_messages: Vec = vec![]; for (response, statement) in result_messages.iter_mut().zip(orig_cql.statements.iter()) @@ -397,9 +419,7 @@ impl SimpleRedisCache { self.clear_table_cache(); } CacheableState::Read { table_name } => { - if let Some(table_cache_schema) = - self.caching_schema.get(table_name.as_str()) - { + if let Some(table_cache_schema) = self.caching_schema.get(&table_name) { if let Ok((redis_key, hash_key)) = build_redis_key_from_cql3(statement, table_cache_schema) { @@ -439,7 +459,7 @@ impl SimpleRedisCache { } } } - Ok(result_messages.to_vec()) + Ok(result_messages) } } @@ -458,7 +478,7 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { CassandraStatement::Select(select) => { if has_params { CacheableState::Delete { - table_name: table_name.into(), + table_name: table_name.clone(), } } else if select.filtering { CacheableState::Skip { @@ -470,28 +490,28 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { } } else { CacheableState::Read { - table_name: table_name.to_string(), + table_name: table_name.clone(), } } } CassandraStatement::Insert(insert) => { if has_params || insert.if_not_exists { CacheableState::Delete { - table_name: table_name.into(), + table_name: table_name.clone(), } } else { CacheableState::Update { - table_name: table_name.into(), + table_name: table_name.clone(), } } } CassandraStatement::DropTable(_) => CacheableState::Drop { - table_name: table_name.into(), + table_name: table_name.clone(), }, CassandraStatement::Update(update) => { if has_params || update.if_exists { CacheableState::Delete { - table_name: table_name.into(), + table_name: table_name.clone(), } } else { for assignment_element in &update.assignments { @@ -501,7 +521,7 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { update.table_name, assignment_element.name ); return CacheableState::Delete { - table_name: table_name.into(), + table_name: table_name.clone(), }; } if assignment_element.name.idx.is_some() { @@ -510,12 +530,12 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { update.table_name, assignment_element.name ); return CacheableState::Delete { - table_name: table_name.into(), + table_name: table_name.clone(), }; } } CacheableState::Update { - table_name: table_name.into(), + table_name: table_name.clone(), } } } @@ -536,14 +556,13 @@ fn is_cacheable(statement: &CassandraStatement) -> CacheableState { /// the cassandra range key (may be partially specified) fn build_query_redis_key_from_value_map( table_cache_schema: &TableCacheSchema, - query_values: &BTreeMap>, + query_values: &BTreeMap>, table_name: &str, ) -> Result { let mut key = table_name.as_bytes().to_vec(); - for c_name in &table_cache_schema.partition_key { - let column_name = c_name.to_lowercase(); + for column_name in &table_cache_schema.partition_key { debug!("processing partition key segment: {}", column_name); - match query_values.get(column_name.as_str()) { + match query_values.get(&Operand::Column(column_name.clone())) { None => { return Err(CacheableState::Skip { reason: format!( @@ -572,9 +591,8 @@ fn build_query_redis_key_from_value_map( } let mut skipping = false; - for c_name in &table_cache_schema.range_key { - let column_name = c_name.to_lowercase(); - match query_values.get(column_name.as_str()) { + for column_name in &table_cache_schema.range_key { + match query_values.get(&Operand::Column(column_name.clone())) { None => { skipping = true; } @@ -611,37 +629,21 @@ fn build_query_redis_key_from_value_map( /// the cassandra range key (may be partially specified) fn build_query_redis_hash_from_value_map( table_cache_schema: &TableCacheSchema, - query_values: &BTreeMap>, + query_values: &BTreeMap>, select: &Select, ) -> Result { let mut my_values = query_values.clone(); - for c_name in &table_cache_schema.partition_key { - let column_name = c_name.to_lowercase(); - my_values.remove(&column_name); + for column_name in &table_cache_schema.partition_key { + my_values.remove(&Operand::Column(column_name.clone())); } - for c_name in &table_cache_schema.range_key { - let column_name = c_name.to_lowercase(); - my_values.remove(&column_name); + for column_name in &table_cache_schema.range_key { + my_values.remove(&Operand::Column(column_name.clone())); } let mut str = if select.columns.is_empty() { String::from("WHERE ") } else { - let mut tmp = select - .columns - .iter() - .map(|select_element| match select_element { - SelectElement::Star => SelectElement::Star, - SelectElement::Column(named) => SelectElement::Column(Named { - name: named.name.to_lowercase(), - alias: named.alias.as_ref().map(|name| name.to_lowercase()), - }), - SelectElement::Function(named) => SelectElement::Function(Named { - name: named.name.to_lowercase(), - alias: named.alias.as_ref().map(|name| name.to_lowercase()), - }), - }) - .join(", "); + let mut tmp = select.columns.iter().join(", "); tmp.push_str(" WHERE "); tmp }; @@ -658,16 +660,15 @@ fn build_query_redis_hash_from_value_map( } fn populate_value_map_from_where_clause( - value_map: &mut BTreeMap>, + value_map: &mut BTreeMap>, where_clause: &[RelationElement], ) { for relation_element in where_clause { - let column_name = relation_element.obj.to_string().to_lowercase(); - let value = value_map.get_mut(column_name.as_str()); + let value = value_map.get_mut(&relation_element.obj); if let Some(vec) = value { vec.push(relation_element.clone()) } else { - value_map.insert(column_name, vec![relation_element.clone()]); + value_map.insert(relation_element.obj.clone(), vec![relation_element.clone()]); }; } } @@ -676,7 +677,7 @@ fn build_redis_key_from_cql3( statement: &CassandraStatement, table_cache_schema: &TableCacheSchema, ) -> Result<(Bytes, Bytes), CacheableState> { - let mut value_map: BTreeMap> = BTreeMap::new(); + let mut value_map: BTreeMap> = BTreeMap::new(); match statement { CassandraStatement::Select(select) => { populate_value_map_from_where_clause(&mut value_map, &select.where_clause); @@ -691,18 +692,18 @@ fn build_redis_key_from_cql3( } CassandraStatement::Insert(insert) => { - for (c_name, operand) in insert.get_value_map().into_iter() { - let column_name = c_name.to_lowercase(); + for (column_name, operand) in insert.get_value_map().into_iter() { let relation_element = RelationElement { obj: Operand::Column(column_name.clone()), oper: RelationOperator::Equal, value: operand.clone(), }; - let value = value_map.get_mut(column_name.as_str()); + let key = Operand::Column(column_name.clone()); + let value = value_map.get_mut(&key); if let Some(vec) = value { vec.push(relation_element) } else { - value_map.insert(column_name, vec![relation_element]); + value_map.insert(key, vec![relation_element]); }; } Ok(( @@ -811,6 +812,7 @@ mod test { use crate::transforms::{Transform, Transforms}; use bytes::{Bytes, BytesMut}; use cql3_parser::cassandra_statement::CassandraStatement; + use cql3_parser::common::Identifier; use metrics::register_counter; use std::collections::HashMap; @@ -825,8 +827,8 @@ mod test { #[test] fn equal_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], - range_key: vec!["x".to_string(), "y".to_string()], + partition_key: vec![Identifier::parse("z")], + range_key: vec![Identifier::parse("x"), Identifier::parse("y")], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); @@ -842,7 +844,7 @@ mod test { #[test] fn insert_simple_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], + partition_key: vec![Identifier::parse("z")], range_key: vec![], }; @@ -859,8 +861,8 @@ mod test { #[test] fn insert_simple_clustering_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], - range_key: vec!["c".to_string()], + partition_key: vec![Identifier::parse("z")], + range_key: vec![Identifier::parse("c")], }; let ast = build_query("INSERT INTO foo (z, c, v) VALUES (1, 'yo' , 123)"); @@ -875,7 +877,7 @@ mod test { #[test] fn update_simple_clustering_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], + partition_key: vec![Identifier::parse("z")], range_key: vec![], }; @@ -891,8 +893,8 @@ mod test { #[test] fn check_deterministic_order_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], - range_key: vec!["x".to_string(), "y".to_string()], + partition_key: vec![Identifier::parse("z")], + range_key: vec![Identifier::parse("x"), Identifier::parse("y")], }; let ast = build_query("SELECT * FROM foo WHERE z = 1 AND x = 123 AND y = 965"); @@ -915,7 +917,7 @@ mod test { #[test] fn range_exclusive_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], + partition_key: vec![Identifier::parse("z")], range_key: vec![], }; @@ -932,7 +934,7 @@ mod test { #[test] fn range_inclusive_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], + partition_key: vec![Identifier::parse("z")], range_key: vec![], }; @@ -949,7 +951,7 @@ mod test { #[test] fn single_pk_only_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["id".to_string()], + partition_key: vec![Identifier::parse("id")], range_key: vec![], }; @@ -970,7 +972,7 @@ mod test { #[test] fn compound_pk_only_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string(), "y".to_string()], + partition_key: vec![Identifier::parse("z"), Identifier::parse("y")], range_key: vec![], }; @@ -987,7 +989,7 @@ mod test { #[test] fn open_range_test() { let table_cache_schema = TableCacheSchema { - partition_key: vec!["z".to_string()], + partition_key: vec![Identifier::parse("z")], range_key: vec![], }; From 3b2ea5299713968eb872fe670eed2c6f9192b796 Mon Sep 17 00:00:00 2001 From: Claude Warren Date: Mon, 23 May 2022 05:54:50 +0100 Subject: [PATCH 60/60] fixes as per review --- shotover-proxy/src/transforms/protect/mod.rs | 11 ++--- shotover-proxy/src/transforms/redis/cache.md | 17 +++++++ shotover-proxy/src/transforms/redis/cache.rs | 47 +++++--------------- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/shotover-proxy/src/transforms/protect/mod.rs b/shotover-proxy/src/transforms/protect/mod.rs index adc0656f2..d6402879e 100644 --- a/shotover-proxy/src/transforms/protect/mod.rs +++ b/shotover-proxy/src/transforms/protect/mod.rs @@ -35,7 +35,7 @@ pub struct Protect { } impl Protect { - /// encodes a Protected object into a byte array. This is here to centeralize the serde for + /// encodes a Protected object into a byte array. This is here to centralize the serde for /// the Protected object. /// Returns an error if a Plaintext Protected object is passed fn encode(protected: &Protected) -> Result> { @@ -54,7 +54,7 @@ impl Protect { /// get the list of protected columns for the specified table name. Will return `None` if no columns /// are defined for the table. fn get_protected_columns(&self, table_name: &FQName) -> Option<&Vec> { - // TODO replace "" with cached keyspace name + // TODO replace `Identifier::default()` with cached keyspace name if let Some(tables) = self .keyspace_table_columns .get(table_name.extract_keyspace(&Identifier::default())) @@ -65,7 +65,7 @@ impl Protect { } } - /// extractes the protected object from the message value. Resulting object is a Protected::Ciphertext + /// extracts the protected object from the message value. Resulting object is a Protected::Ciphertext fn extract_protected(&self, value: &MessageValue) -> Result { match value { MessageValue::Bytes(b) => Protect::decode(&b[..]), @@ -82,9 +82,6 @@ impl Protect { /// determines if columns in the CassandraStatement need to be encrypted and encrypts them. Returns `true` if any columns were changed. /// * `statement` the statement to encrypt. - /// * `columns` the column names to encrypt. - /// * `key_source` the key manager with encryption keys. - /// * `key_id` the key within the manager to use. async fn encrypt_columns(&self, statement: &mut CassandraStatement) -> Result { let mut data_changed = false; if let Some(table_name) = cql_statement::get_table_name(statement) { @@ -235,7 +232,6 @@ fn decrypt(ciphertext: Vec, nonce: Nonce, sym_key: &Key) -> Result for MessageValue { fn from(p: Protected) -> Self { match p { @@ -243,7 +239,6 @@ impl From for MessageValue { "tried to move unencrypted value to plaintext without explicitly calling decrypt" ), Protected::Ciphertext { .. } => { - //MessageValue::Bytes(Bytes::from(serde_json::to_vec(&p).unwrap())) MessageValue::Bytes(Bytes::from(bincode::serialize(&p).unwrap())) } } diff --git a/shotover-proxy/src/transforms/redis/cache.md b/shotover-proxy/src/transforms/redis/cache.md index 88b85265f..e3577531f 100644 --- a/shotover-proxy/src/transforms/redis/cache.md +++ b/shotover-proxy/src/transforms/redis/cache.md @@ -121,3 +121,20 @@ results may modify the cache. Any command with a cacheable state of Read, Update or Delete is processed again on return and the cache updated appropriately. +### Conceptual process flow + +Conceptually the code has 4 vectors of messages, each vec can be considered its own stage of processing. + 1. `messages_cass_request`: + * the cassandra requests that the transform receives. + 2. `messages_redis_request`: + * each query in each cassandra request in `messages_cass_request`, if it is cacheable, is transformed into a redis request + * each request gets sent to the redis server + 3. `messages_redis_response`: + * the redis responses we get back from the server + 4. `messages_cass_response`: + * for cached queries we return the result from the cache. + * otherwise we forward the request to the source and return the result while monitoring for: + * Queries that cause eviction of data from the cache: delete, insert, update, etc. + * Queries that update data in the cache: Select. + + diff --git a/shotover-proxy/src/transforms/redis/cache.rs b/shotover-proxy/src/transforms/redis/cache.rs index d51dbc21c..f4fea76c1 100644 --- a/shotover-proxy/src/transforms/redis/cache.rs +++ b/shotover-proxy/src/transforms/redis/cache.rs @@ -230,39 +230,34 @@ impl SimpleRedisCache { match redis_response.frame() { Some(Frame::Redis(redis_frame)) => { match redis_frame { - RedisFrame::SimpleString(_) => Err(CacheableState::Err { - reason: "Redis returned a simple string".into(), - }), RedisFrame::Error(e) => { return Err(CacheableState::Err { reason: e.to_string(), }) } - RedisFrame::Integer(_) => Err(CacheableState::Err { - reason: "Redis returned an int value".into(), - }), RedisFrame::BulkString(redis_bytes) => { // Redis response contains serialized version of result struct from CassandraOperation::Result( result ) let mut cursor = Cursor::new(redis_bytes.as_ref()); - let answer = - CassandraResult::from_cursor(&mut cursor, Version::V4); - if let Ok(result) = answer { - Ok(result) - } else { - Err(CacheableState::Err { - reason: answer.unwrap_err().to_string(), + CassandraResult::from_cursor(&mut cursor, Version::V4) + .map_err(|err| CacheableState::Err { + reason: err.to_string(), }) - } } - RedisFrame::Array(_) => Err(CacheableState::Err { - reason: "Redis returned an array value".into(), - }), RedisFrame::Null => { self.missed_requests.increment(1); Err(CacheableState::Skip { reason: "No cache results".into(), }) } + RedisFrame::SimpleString(_) => Err(CacheableState::Err { + reason: "Redis returned a simple string".into(), + }), + RedisFrame::Integer(_) => Err(CacheableState::Err { + reason: "Redis returned an int value".into(), + }), + RedisFrame::Array(_) => Err(CacheableState::Err { + reason: "Redis returned an array value".into(), + }), } } @@ -300,24 +295,6 @@ impl SimpleRedisCache { &mut self, mut cassandra_messages: Messages, ) -> Result { - // This function is a little hard to follow, so here's an overview. - // We have 4 vecs of messages, each vec can be considered its own stage of processing. - // 1. messages_cass_request: - // * the cassandra requests that the function receives. - // 2. messages_redis_request: - // * each query in each cassandra request in messages_cass_request is transformed into a redis request - // * each request gets sent to the redis server - // 3. messages_redis_response: - // * the redis responses we get back from the server - // 4. messages_cass_response: - // * Well messages_cass_response is what we would have called this, in reality we reuse the messages_cass_request vec because its cheaper. - // * To create each response we go through each request in messages_cass_request: - // + if the request is a CassandraOperation::Batch then we consume a message from messages_redis_response for each query in the batch - // - if any of the messages are errors then generate a cassandra ERROR otherwise generate a VOID RESULT. - // - we can get away with this because batches can only contain INSERT/UPDATE/DELETE and therefore always contain either an ERROR or a VOID RESULT - // + if the request is a CassandraOperation::Query then we consume a single message from messages_redis_response converting it to a cassandra response - // * These are the cassandra responses that we return from the function. - debug!("read_from_cache called"); // build the cache query