From 604bf995fb1156ad4922c2f5117a4f02852b407a Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 1 May 2023 20:43:22 -0700 Subject: [PATCH 01/11] Some queries --- Cargo.lock | 14 +++- Cargo.toml | 4 +- src/client.rs | 8 ++- src/config.rs | 1 + src/errors.rs | 3 + src/pool.rs | 2 + src/query_router.rs | 153 ++++++++++++++++++++++++++++++-------------- 7 files changed, 133 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dad2e6c5..6a3e0011 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -882,7 +882,7 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pgcat" -version = "1.0.1" +version = "1.0.2-alpha1" dependencies = [ "arc-swap", "async-trait", @@ -1297,6 +1297,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a" dependencies = [ "log", + "sqlparser_derive", +] + +[[package]] +name = "sqlparser_derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 436c3dd5..0d5c6446 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "1.0.1" +version = "1.0.2-alpha1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -19,7 +19,7 @@ serde_derive = "1" regex = "1" num_cpus = "1" once_cell = "1" -sqlparser = "0.33.0" +sqlparser = {version = "0.33", features = ["visitor"] } log = "0.4" arc-swap = "1" env_logger = "0.10" diff --git a/src/client.rs b/src/client.rs index efde7554..de2213a7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -815,7 +815,9 @@ where 'Q' => { if query_router.query_parser_enabled() { - query_router.infer(&message); + if let Ok(ast) = QueryRouter::parse(&message) { + let _ = query_router.infer(&ast); + } } } @@ -823,7 +825,9 @@ where self.buffer.put(&message[..]); if query_router.query_parser_enabled() { - query_router.infer(&message); + if let Ok(ast) = QueryRouter::parse(&message) { + let _ = query_router.infer(&ast); + } } continue; diff --git a/src/config.rs b/src/config.rs index fd7d3912..9d66e48d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -298,6 +298,7 @@ pub struct General { pub admin_username: String, pub admin_password: String, + // Support for auth query pub auth_query: Option, pub auth_query_user: Option, pub auth_query_password: Option, diff --git a/src/errors.rs b/src/errors.rs index fb70c042..f5fe21c3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -24,6 +24,9 @@ pub enum Error { ParseBytesError(String), AuthError(String), AuthPassthroughError(String), + UnsupportedStatement, + QueryRouterParserError(String), + PermissionDeniedTable(String), } #[derive(Clone, PartialEq, Debug)] diff --git a/src/pool.rs b/src/pool.rs index ee8de446..4a735e0f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -395,6 +395,8 @@ impl ConnectionPool { ); } + debug!("Query router: {}", pool_config.query_parser_enabled); + let pool = ConnectionPool { databases: shards, stats: pool_stats, diff --git a/src/query_router.rs b/src/query_router.rs index 5b2ba0c4..c3177e4e 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -12,6 +12,7 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; +use crate::errors::Error; use 
crate::messages::BytesMutReader; use crate::pool::PoolSettings; use crate::sharding::Sharder; @@ -324,10 +325,7 @@ impl QueryRouter { Some((command, value)) } - /// Try to infer which server to connect to based on the contents of the query. - pub fn infer(&mut self, message: &BytesMut) -> bool { - debug!("Inferring role"); - + pub fn parse(message: &BytesMut) -> Result, Error> { let mut message_cursor = Cursor::new(message); let code = message_cursor.get_u8() as char; @@ -353,28 +351,33 @@ impl QueryRouter { query } - _ => return false, + _ => return Err(Error::UnsupportedStatement), }; - let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) { - Ok(ast) => ast, + match Parser::parse_sql(&PostgreSqlDialect {}, &query) { + Ok(ast) => { + debug!("AST: {:?}", ast); + Ok(ast) + } + Err(err) => { - // SELECT ... FOR UPDATE won't get parsed correctly. debug!("{}: {}", err, query); - self.active_role = Some(Role::Primary); - return false; + Err(Error::QueryRouterParserError(err.to_string())) } - }; + } + } - debug!("AST: {:?}", ast); + /// Try to infer which server to connect to based on the contents of the query. + pub fn infer(&mut self, ast: &Vec) -> Result<(), Error> { + debug!("Inferring role"); if ast.is_empty() { // That's weird, no idea, let's go to primary self.active_role = Some(Role::Primary); - return false; + return Err(Error::QueryRouterParserError("empty query".into())); } - for q in &ast { + for q in ast { match q { // All transactions go to the primary, probably a write. StartTransaction { .. } => { @@ -418,7 +421,7 @@ impl QueryRouter { }; } - true + Ok(()) } /// Parse the shard number from the Bind message @@ -862,7 +865,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -881,7 +884,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -893,7 +896,7 @@ mod test { let query = simple_query("SELECT * FROM items WHERE id = 5"); assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), None); } @@ -913,7 +916,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(&res)); + assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -1077,11 +1080,11 @@ mod test { assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); @@ -1142,15 +1145,24 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap()) + .is_ok()); assert_eq!(qr.role(), Role::Primary); - assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 
2;")).unwrap()) + .is_ok()); assert_eq!(qr.role(), Role::Replica); - assert!(qr.infer(&simple_query( - "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.role(), Role::Primary); } @@ -1208,47 +1220,84 @@ mod test { qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.shards = 3; - assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5"))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap()) + .is_ok()); assert_eq!(qr.shard(), 2); - assert!(qr.infer(&simple_query( - "SELECT one, two, three FROM public.data WHERE id = 6" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT one, two, three FROM public.data WHERE id = 6" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); - assert!(qr.infer(&simple_query( - "SELECT * FROM data + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT * FROM data INNER JOIN t2 ON data.id = 5 AND t2.data_id = data.id WHERE data.id = 5" - ))); + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 2); // Shard did not move because we couldn't determine the sharding key since it could be ambiguous // in the query. - assert!(qr.infer(&simple_query( - "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 2); - assert!(qr.infer(&simple_query( - r#"SELECT * FROM "public"."data" WHERE "id" = 6"# - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + r#"SELECT * FROM "public"."data" WHERE "id" = 6"# + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); - assert!(qr.infer(&simple_query( - r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 2); // Super unique sharding key qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string()); - assert!(qr.infer(&simple_query( - "SELECT * FROM table_x WHERE unique_enough_column_name = 6" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT * FROM table_x WHERE unique_enough_column_name = 6" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); - assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5")) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); } @@ -1272,11 +1321,21 @@ mod test { qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.shards = 3; - assert!(qr.infer(&simple_query(stmt))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query(stmt)).unwrap()) + .is_ok()); assert_eq!(qr.placeholders.len(), 1); assert!(qr.infer_shard_from_bind(&bind)); assert_eq!(qr.shard(), 2); assert!(qr.placeholders.is_empty()); } + + #[test] + fn test_parse() { + let query = simple_query("SELECT * FROM pg_database"); + let ast = QueryRouter::parse(&query); + + assert!(ast.is_ok()); + } } From 54e986575d9c962410aa13df677116ac1d2a091c Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 May 2023 
15:04:32 -0700 Subject: [PATCH 02/11] Plugins!! --- Cargo.lock | 21 ++++++++++++ Cargo.toml | 3 +- pgcat.toml | 3 ++ src/admin.rs | 2 +- src/client.rs | 83 +++++++++++++++++++++++++++++++++++++++++++++ src/config.rs | 3 ++ src/errors.rs | 4 ++- src/lib.rs | 5 +++ src/main.rs | 54 ++++++++++------------------- src/messages.rs | 35 +++++++++++++++++++ src/pool.rs | 14 ++++++++ src/query_router.rs | 63 +++++++++++++++++++++++++++------- 12 files changed, 239 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6a3e0011..6724a92c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -913,6 +913,7 @@ dependencies = [ "rustls-pemfile", "serde", "serde_derive", + "serde_json", "sha-1", "sha2", "socket2", @@ -1174,6 +1175,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + [[package]] name = "scopeguard" version = "1.1.0" @@ -1201,6 +1208,9 @@ name = "serde" version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" +dependencies = [ + "serde_derive", +] [[package]] name = "serde_derive" @@ -1213,6 +1223,17 @@ dependencies = [ "syn 2.0.9", ] +[[package]] +name = "serde_json" +version = "1.0.96" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.1" diff --git a/Cargo.toml b/Cargo.toml index 0d5c6446..80549821 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ rand = "0.8" chrono = "0.4" sha-1 = "0.10" toml = "0.7" -serde = "1" +serde = { version = "1", features = ["derive"] } serde_derive = "1" regex = "1" num_cpus = "1" @@ -44,6 +44,7 @@ webpki-roots = "0.23" rustls = { version = "0.21", features = ["dangerous_configuration"] } trust-dns-resolver = "0.22.0" tokio-test = "0.4.2" +serde_json = "1" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/pgcat.toml b/pgcat.toml index c844ce1f..45e25382 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -77,6 +77,9 @@ admin_username = "admin_user" # Password to access the virtual administrative database admin_password = "admin_pass" +# Plugins!! +# plugins = ["pg_table_access", "intercept"] + # pool configs are structured as pool. # the pool_name is what clients use as database name when connecting. 
# For a pool named `sharded_db`, clients access that pool using connection string like diff --git a/src/admin.rs b/src/admin.rs index 03af755c..ceba20c8 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -12,9 +12,9 @@ use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; use crate::errors::Error; use crate::messages::*; +use crate::pool::ClientServerMap; use crate::pool::{get_all_pools, get_pool}; use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState}; -use crate::ClientServerMap; pub fn generate_server_info_for_admin() -> BytesMut { let mut server_info = BytesMut::new(); diff --git a/src/client.rs b/src/client.rs index de2213a7..7a53eeb6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::constants::*; use crate::messages::*; +use crate::plugins::PluginOutput; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; @@ -765,6 +766,9 @@ where self.stats.register(self.stats.clone()); + // Error returned by one of the plugins. + let mut plugin_output = None; + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. @@ -816,6 +820,22 @@ where 'Q' => { if query_router.query_parser_enabled() { if let Ok(ast) = QueryRouter::parse(&message) { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + continue; + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } + + _ => (), + }; + let _ = query_router.infer(&ast); } } @@ -826,6 +846,10 @@ where if query_router.query_parser_enabled() { if let Ok(ast) = QueryRouter::parse(&message) { + if let Ok(output) = query_router.execute_plugins(&ast).await { + plugin_output = Some(output); + } + let _ = query_router.infer(&ast); } } @@ -861,6 +885,18 @@ where continue; } + // Check on plugin results. + match plugin_output { + Some(PluginOutput::Deny(error)) => { + self.buffer.clear(); + error_response(&mut self.write, &error).await?; + plugin_output = None; + continue; + } + + _ => (), + }; + // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config // when starting a query. @@ -1089,6 +1125,27 @@ where match code { // Query 'Q' => { + if query_router.query_parser_enabled() { + if let Ok(ast) = QueryRouter::parse(&message) { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + continue; + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } + + _ => (), + }; + + let _ = query_router.infer(&ast); + } + } debug!("Sending query to server"); self.send_and_receive_loop( @@ -1128,6 +1185,14 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 
'P' => { + if query_router.query_parser_enabled() { + if let Ok(ast) = QueryRouter::parse(&message) { + if let Ok(output) = query_router.execute_plugins(&ast).await { + plugin_output = Some(output); + } + } + } + self.buffer.put(&message[..]); } @@ -1159,6 +1224,24 @@ where 'S' => { debug!("Sending query to server"); + match plugin_output { + Some(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + plugin_output = None; + self.buffer.clear(); + continue; + } + + Some(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + plugin_output = None; + self.buffer.clear(); + continue; + } + + _ => (), + }; + self.buffer.put(&message[..]); let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; diff --git a/src/config.rs b/src/config.rs index 9d66e48d..aa98421c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -302,6 +302,8 @@ pub struct General { pub auth_query: Option, pub auth_query_user: Option, pub auth_query_password: Option, + + pub query_router_plugins: Option>, } impl General { @@ -402,6 +404,7 @@ impl Default for General { auth_query_user: None, auth_query_password: None, server_lifetime: 1000 * 3600 * 24, // 24 hours, + query_router_plugins: None, } } } diff --git a/src/errors.rs b/src/errors.rs index f5fe21c3..3dfc86fb 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,7 +1,7 @@ //! Errors. /// Various errors. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Error { SocketError(String), ClientSocketError(String, ClientIdentifier), @@ -26,7 +26,9 @@ pub enum Error { AuthPassthroughError(String), UnsupportedStatement, QueryRouterParserError(String), + PermissionDenied(String), PermissionDeniedTable(String), + QueryDenied(String), } #[derive(Clone, PartialEq, Debug)] diff --git a/src/lib.rs b/src/lib.rs index 3a58bb38..db6167db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,6 @@ +pub mod admin; pub mod auth_passthrough; +pub mod client; pub mod config; pub mod constants; pub mod dns_cache; @@ -6,7 +8,10 @@ pub mod errors; pub mod messages; pub mod mirrors; pub mod multi_logger; +pub mod plugins; pub mod pool; +pub mod prometheus; +pub mod query_router; pub mod scram; pub mod server; pub mod sharding; diff --git a/src/main.rs b/src/main.rs index dc48dd58..fe513171 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,37 +61,19 @@ use std::str::FromStr; use std::sync::Arc; use tokio::sync::broadcast; -mod admin; -mod auth_passthrough; -mod client; -mod config; -mod constants; -mod dns_cache; -mod errors; -mod messages; -mod mirrors; -mod multi_logger; -mod pool; -mod prometheus; -mod query_router; -mod scram; -mod server; -mod sharding; -mod stats; -mod tls; - -use crate::config::{get_config, reload_config, VERSION}; -use crate::messages::configure_socket; -use crate::pool::{ClientServerMap, ConnectionPool}; -use crate::prometheus::start_metric_server; -use crate::stats::{Collector, Reporter, REPORTER}; +use pgcat::config::{get_config, reload_config, VERSION}; +use pgcat::messages::configure_socket; +use pgcat::pool::{ClientServerMap, ConnectionPool}; +use pgcat::prometheus::start_metric_server; +use pgcat::stats::{Collector, Reporter, REPORTER}; +use pgcat::dns_cache; fn main() -> Result<(), Box> { - multi_logger::MultiLogger::init().unwrap(); + pgcat::multi_logger::MultiLogger::init().unwrap(); info!("Welcome to PgCat! Meow. 
(Version {})", VERSION); - if !query_router::QueryRouter::setup() { + if !pgcat::query_router::QueryRouter::setup() { error!("Could not setup query router"); std::process::exit(exitcode::CONFIG); } @@ -109,7 +91,7 @@ fn main() -> Result<(), Box> { let runtime = Builder::new_multi_thread().worker_threads(1).build()?; runtime.block_on(async { - match config::parse(&config_file).await { + match pgcat::config::parse(&config_file).await { Ok(_) => (), Err(err) => { error!("Config parse error: {:?}", err); @@ -168,14 +150,14 @@ fn main() -> Result<(), Box> { // Statistics reporting. REPORTER.store(Arc::new(Reporter::default())); - // Starts (if enabled) dns cache before pools initialization - match dns_cache::CachedResolver::from_config().await { - Ok(_) => (), - Err(err) => error!("DNS cache initialization error: {:?}", err), - }; + // Starts (if enabled) dns cache before pools initialization + match dns_cache::CachedResolver::from_config().await { + Ok(_) => (), + Err(err) => error!("DNS cache initialization error: {:?}", err), + }; - // Connection pool that allows to query all shards and replicas. - match ConnectionPool::from_config(client_server_map.clone()).await { + // Connection pool that allows to query all shards and replicas. + match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), Err(err) => { error!("Pool error: {:?}", err); @@ -303,7 +285,7 @@ fn main() -> Result<(), Box> { tokio::task::spawn(async move { let start = chrono::offset::Utc::now().naive_utc(); - match client::client_entrypoint( + match pgcat::client::client_entrypoint( socket, client_server_map, shutdown_rx, @@ -334,7 +316,7 @@ fn main() -> Result<(), Box> { Err(err) => { match err { - errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), + pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), _ => warn!("Client disconnected with error {:?}", err), } diff --git a/src/messages.rs b/src/messages.rs index 0e980fe6..ee4886df 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -20,6 +20,10 @@ pub enum DataType { Text, Int4, Numeric, + Bool, + Oid, + AnyArray, + Any, } impl From<&DataType> for i32 { @@ -28,6 +32,10 @@ impl From<&DataType> for i32 { DataType::Text => 25, DataType::Int4 => 23, DataType::Numeric => 1700, + DataType::Bool => 16, + DataType::Oid => 26, + DataType::AnyArray => 2277, + DataType::Any => 2276, } } } @@ -443,6 +451,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { DataType::Text => -1, DataType::Int4 => 4, DataType::Numeric => -1, + DataType::Bool => 1, + DataType::Oid => 4, + DataType::AnyArray => -1, + DataType::Any => -1, }; row_desc.put_i16(type_size); @@ -481,6 +493,29 @@ pub fn data_row(row: &Vec) -> BytesMut { res } +pub fn data_row_nullable(row: &Vec>) -> BytesMut { + let mut res = BytesMut::new(); + let mut data_row = BytesMut::new(); + + data_row.put_i16(row.len() as i16); + + for column in row { + if let Some(column) = column { + let column = column.as_bytes(); + data_row.put_i32(column.len() as i32); + data_row.put_slice(column); + } else { + data_row.put_i32(-1 as i32); + } + } + + res.put_u8(b'D'); + res.put_i32(data_row.len() as i32 + 4); + res.put(data_row); + + res +} + /// Create a CommandComplete message. 
pub fn command_complete(command: &str) -> BytesMut { let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); diff --git a/src/pool.rs b/src/pool.rs index 4a735e0f..2eed3cdd 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -91,6 +91,7 @@ pub struct PoolSettings { // Connecting user. pub user: User, + pub db: String, // Default server role to connect to. pub default_role: Option, @@ -129,6 +130,8 @@ pub struct PoolSettings { pub auth_query: Option, pub auth_query_user: Option, pub auth_query_password: Option, + + pub plugins: Option>, } impl Default for PoolSettings { @@ -138,6 +141,7 @@ impl Default for PoolSettings { load_balancing_mode: LoadBalancingMode::Random, shards: 1, user: User::default(), + db: String::default(), default_role: None, query_parser_enabled: false, primary_reads_enabled: true, @@ -152,6 +156,7 @@ impl Default for PoolSettings { auth_query: None, auth_query_user: None, auth_query_password: None, + plugins: None, } } } @@ -414,6 +419,7 @@ impl ConnectionPool { // shards: pool_config.shards.clone(), shards: shard_ids.len(), user: user.clone(), + db: pool_name.clone(), default_role: match pool_config.default_role.as_str() { "any" => None, "replica" => Some(Role::Replica), @@ -439,6 +445,7 @@ impl ConnectionPool { auth_query: pool_config.auth_query.clone(), auth_query_user: pool_config.auth_query_user.clone(), auth_query_password: pool_config.auth_query_password.clone(), + plugins: config.general.query_router_plugins.clone(), }, validated: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)), @@ -458,6 +465,13 @@ impl ConnectionPool { } } + // Initialize plugins here if required. + if let Some(plugins) = config.general.query_router_plugins { + if plugins.contains(&String::from("intercept")) { + crate::plugins::intercept::configure(&new_pools); + } + } + POOLS.store(Arc::new(new_pools.clone())); Ok(()) } diff --git a/src/query_router.rs b/src/query_router.rs index c3177e4e..99c8728d 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -6,7 +6,8 @@ use once_cell::sync::OnceCell; use regex::{Regex, RegexSet}; use sqlparser::ast::Statement::{Query, StartTransaction}; use sqlparser::ast::{ - BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value, + BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor, + Value, }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; @@ -14,6 +15,7 @@ use sqlparser::parser::Parser; use crate::config::Role; use crate::errors::Error; use crate::messages::BytesMutReader; +use crate::plugins::{Intercept, Plugin, PluginOutput, TableAccess}; use crate::pool::PoolSettings; use crate::sharding::Sharder; @@ -130,6 +132,10 @@ impl QueryRouter { self.pool_settings = pool_settings; } + pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings { + &self.pool_settings + } + /// Try to parse a command and execute it. 
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { let mut message_cursor = Cursor::new(message_buffer); @@ -335,7 +341,7 @@ impl QueryRouter { // Query 'Q' => { let query = message_cursor.read_string().unwrap(); - debug!("Query: '{}'", query); + error!("Query: '{}'", query); query } @@ -347,7 +353,7 @@ impl QueryRouter { // Reads query string let query = message_cursor.read_string().unwrap(); - debug!("Prepared statement: '{}'", query); + error!("Prepared statement: '{}'", query); query } @@ -355,11 +361,7 @@ impl QueryRouter { }; match Parser::parse_sql(&PostgreSqlDialect {}, &query) { - Ok(ast) => { - debug!("AST: {:?}", ast); - Ok(ast) - } - + Ok(ast) => Ok(ast), Err(err) => { debug!("{}: {}", err, query); Err(Error::QueryRouterParserError(err.to_string())) @@ -786,6 +788,32 @@ impl QueryRouter { } } + /// Add your plugins here and execute them. + pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> { + if let Some(plugins) = &self.pool_settings.plugins { + if plugins.contains(&String::from("intercept")) { + let mut intercept = Intercept {}; + let result = intercept.run(&self, ast).await; + + if let Ok(PluginOutput::Intercept(output)) = result { + return Ok(PluginOutput::Intercept(output)); + } + } + + if plugins.contains(&String::from("pg_table_access")) { + let mut table_access = TableAccess { + forbidden_tables: vec![String::from("pg_database"), String::from("pg_roles")], + }; + + if let Ok(PluginOutput::Deny(error)) = table_access.run(&self, ast).await { + return Ok(PluginOutput::Deny(error)); + } + } + } + + Ok(PluginOutput::Allow) + } + fn set_sharding_key(&mut self, sharding_key: i64) -> Option<String> { let sharder = Sharder::new( self.pool_settings.shards, @@ -813,11 +841,22 @@ impl QueryRouter { /// Should we attempt to parse queries? 
pub fn query_parser_enabled(&self) -> bool { let enabled = match self.query_parser_enabled { - None => self.pool_settings.query_parser_enabled, - Some(value) => value, - }; + None => { + debug!( + "Using pool settings, query_parser_enabled: {}", + self.pool_settings.query_parser_enabled + ); + self.pool_settings.query_parser_enabled + } - debug!("Query parser enabled: {}", enabled); + Some(value) => { + debug!( + "Using query parser override, query_parser_enabled: {}", + value + ); + value + } + }; enabled } From beb0586fca657520252a04855cf43cc73399af22 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 May 2023 15:10:22 -0700 Subject: [PATCH 03/11] cleanup --- src/main.rs | 2 +- src/pool.rs | 2 -- src/query_router.rs | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/main.rs b/src/main.rs index fe513171..6af4db95 100644 --- a/src/main.rs +++ b/src/main.rs @@ -62,11 +62,11 @@ use std::sync::Arc; use tokio::sync::broadcast; use pgcat::config::{get_config, reload_config, VERSION}; +use pgcat::dns_cache; use pgcat::messages::configure_socket; use pgcat::pool::{ClientServerMap, ConnectionPool}; use pgcat::prometheus::start_metric_server; use pgcat::stats::{Collector, Reporter, REPORTER}; -use pgcat::dns_cache; fn main() -> Result<(), Box> { pgcat::multi_logger::MultiLogger::init().unwrap(); diff --git a/src/pool.rs b/src/pool.rs index 2eed3cdd..158ff559 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -400,8 +400,6 @@ impl ConnectionPool { ); } - debug!("Query router: {}", pool_config.query_parser_enabled); - let pool = ConnectionPool { databases: shards, stats: pool_stats, diff --git a/src/query_router.rs b/src/query_router.rs index 99c8728d..820119a0 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -341,7 +341,7 @@ impl QueryRouter { // Query 'Q' => { let query = message_cursor.read_string().unwrap(); - error!("Query: '{}'", query); + debug!("Query: '{}'", query); query } @@ -353,7 +353,7 @@ impl QueryRouter { // Reads query string let query = message_cursor.read_string().unwrap(); - error!("Prepared statement: '{}'", query); + debug!("Prepared statement: '{}'", query); query } From ba2a1cc79c8b5dfbc367342fab78cbced3ddd8df Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 May 2023 15:11:22 -0700 Subject: [PATCH 04/11] actual names --- pgcat.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgcat.toml b/pgcat.toml index 45e25382..dfb57822 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -78,7 +78,7 @@ admin_username = "admin_user" admin_password = "admin_pass" # Plugins!! -# plugins = ["pg_table_access", "intercept"] +# query_router_plugins = ["pg_table_access", "intercept"] # pool configs are structured as pool. # the pool_name is what clients use as database name when connecting. From 57f870c7246292d5d63a53ab25f42ea3953a8265 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 May 2023 15:15:00 -0700 Subject: [PATCH 05/11] the actual plugins --- src/plugins/intercept.rs | 263 ++++++++++++++++++++++++++++++++++++ src/plugins/mod.rs | 40 ++++++ src/plugins/table_access.rs | 50 +++++++ 3 files changed, 353 insertions(+) create mode 100644 src/plugins/intercept.rs create mode 100644 src/plugins/mod.rs create mode 100644 src/plugins/table_access.rs diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs new file mode 100644 index 00000000..0adbb177 --- /dev/null +++ b/src/plugins/intercept.rs @@ -0,0 +1,263 @@ +//! The intercept plugin. +//! +//! It intercepts queries and returns fake results. 
+ +use arc_swap::ArcSwap; +use async_trait::async_trait; +use bytes::{BufMut, BytesMut}; +use once_cell::sync::Lazy; +use serde_json::{json, Value}; +use sqlparser::ast::Statement; +use std::collections::HashMap; + +use log::debug; +use std::sync::Arc; + +use crate::{ + errors::Error, + messages::{command_complete, data_row_nullable, row_description, DataType}, + plugins::{Plugin, PluginOutput}, + pool::{PoolIdentifier, PoolMap}, + query_router::QueryRouter, +}; + +pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, Value>>> = + Lazy::new(|| ArcSwap::from_pointee(HashMap::new())); + +/// Configure the intercept plugin. +pub fn configure(pools: &PoolMap) { + let mut config = HashMap::new(); + for (identifier, _) in pools.iter() { + // TODO: make this configurable from a text config. + let value = fool_datagrip(&identifier.db, &identifier.user); + config.insert(identifier.clone(), value); + } + + CONFIG.store(Arc::new(config)); +} + +/// The intercept plugin. +pub struct Intercept; + +#[async_trait] +impl Plugin for Intercept { + async fn run( + &mut self, + query_router: &QueryRouter, + ast: &Vec<Statement>, + ) -> Result<PluginOutput, Error> { + if ast.is_empty() { + return Ok(PluginOutput::Allow); + } + + let mut result = BytesMut::new(); + let query_map = match CONFIG.load().get(&PoolIdentifier::new( + &query_router.pool_settings().db, + &query_router.pool_settings().user.username, + )) { + Some(query_map) => query_map.clone(), + None => return Ok(PluginOutput::Allow), + }; + + for q in ast { + // Normalization + let q = q.to_string().to_ascii_lowercase(); + + for target in query_map.as_array().unwrap().iter() { + if target["query"].as_str().unwrap() == q { + debug!("Query matched: {}", q); + + let rd = target["schema"] + .as_array() + .unwrap() + .iter() + .map(|row| { + let row = row.as_object().unwrap(); + ( + row["name"].as_str().unwrap(), + match row["data_type"].as_str().unwrap() { + "text" => DataType::Text, + "anyarray" => DataType::AnyArray, + "oid" => DataType::Oid, + "bool" => DataType::Bool, + "int4" => DataType::Int4, + _ => DataType::Any, + }, + ) + }) + .collect::<Vec<(&str, DataType)>>(); + + result.put(row_description(&rd)); + + target["result"].as_array().unwrap().iter().for_each(|row| { + let row = row + .as_array() + .unwrap() + .iter() + .map(|s| { + let s = s.as_str().unwrap().to_string(); + + if s == "" { + None + } else { + Some(s) + } + }) + .collect::<Vec<Option<String>>>(); + result.put(data_row_nullable(&row)); + }); + + result.put(command_complete("SELECT")); + } + } + } + + if !result.is_empty() { + result.put_u8(b'Z'); + result.put_i32(5); + result.put_u8(b'I'); + + return Ok(PluginOutput::Intercept(result)); + } else { + Ok(PluginOutput::Allow) + } + } +} + +/// Make IntelliJ SQL plugin believe it's talking to an actual database +/// instead of PgCat. 
+fn fool_datagrip(database: &str, user: &str) -> Value { + json!([ + { + "query": "select current_database() as a, current_schemas(false) as b", + "schema": [ + { + "name": "a", + "data_type": "text", + }, + { + "name": "b", + "data_type": "anyarray", + }, + ], + + "result": [ + [database, "{public}"], + ], + }, + { + "query": "select current_database(), current_schema(), current_user", + "schema": [ + { + "name": "current_database", + "data_type": "text", + }, + { + "name": "current_schema", + "data_type": "text", + }, + { + "name": "current_user", + "data_type": "text", + } + ], + + "result": [ + ["sharded_db", "public", "sharding_user"], + ], + }, + { + "query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end", + "schema": [ + { + "name": "id", + "data_type": "oid", + }, + { + "name": "name", + "data_type": "text", + }, + { + "name": "description", + "data_type": "text", + }, + { + "name": "is_template", + "data_type": "bool", + }, + { + "name": "allow_connections", + "data_type": "bool", + }, + { + "name": "owner", + "data_type": "text", + } + ], + "result": [ + ["16387", database, "", "f", "t", user], + ] + }, + { + "query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid", + "schema": [ + { + "name": "role_id", + "data_type": "oid", + }, + { + "name": "role_name", + "data_type": "text", + }, + { + "name": "is_super", + "data_type": "bool", + }, + { + "name": "is_inherit", + "data_type": "bool", + }, + { + "name": "can_createrole", + "data_type": "bool", + }, + { + "name": "can_createdb", + "data_type": "bool", + }, + { + "name": "can_login", + "data_type": "bool", + }, + { + "name": "is_replication", + "data_type": "bool", + }, + { + "name": "conn_limit", + "data_type": "int4", + }, + { + "name": "valid_until", + "data_type": "text", + }, + { + "name": "bypass_rls", + "data_type": "bool", + }, + { + "name": "config", + "data_type": "text", + }, + { + "name": "description", + "data_type": "text", + }, + ], + "result": [ + ["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""], + ["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""], + ] + } + ]) +} diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs new file mode 100644 index 00000000..326859c2 --- /dev/null +++ b/src/plugins/mod.rs @@ -0,0 +1,40 @@ +//! The plugin ecosystem. +//! +//! Currently plugins only grant access or deny access to the database for a particual query. +//! Example use cases: +//! - block known bad queries +//! - block access to system catalogs +//! - block dangerous modifications like `DROP TABLE` +//! - etc +//! 
+ +pub mod intercept; +pub mod table_access; + +use crate::{errors::Error, query_router::QueryRouter}; +use async_trait::async_trait; +use bytes::BytesMut; +use sqlparser::ast::Statement; + +pub use intercept::Intercept; +pub use table_access::TableAccess; + +#[derive(Clone)] +pub enum PluginOutput { + Allow, + Deny(String), + Overwrite(Vec<Statement>), + Intercept(BytesMut), +} + +#[async_trait] +pub trait Plugin { + // Custom output is allowed because we want to extend this system + // to rewriting queries some day. So an output of a plugin could be + // a rewritten AST. + async fn run( + &mut self, + query_router: &QueryRouter, + ast: &Vec<Statement>, + ) -> Result<PluginOutput, Error>; +} diff --git a/src/plugins/table_access.rs b/src/plugins/table_access.rs new file mode 100644 index 00000000..2e23278a --- /dev/null +++ b/src/plugins/table_access.rs @@ -0,0 +1,50 @@ +//! This query router plugin will check if the user can access a particular +//! table as part of their query. If they can't, the query will not be routed. + +use async_trait::async_trait; +use sqlparser::ast::{visit_relations, Statement}; + +use crate::{ + errors::Error, + plugins::{Plugin, PluginOutput}, + query_router::QueryRouter, +}; + +use core::ops::ControlFlow; + +pub struct TableAccess { + pub forbidden_tables: Vec<String>, +} + +#[async_trait] +impl Plugin for TableAccess { + async fn run( + &mut self, + _query_router: &QueryRouter, + ast: &Vec<Statement>, + ) -> Result<PluginOutput, Error> { + let mut found = None; + + visit_relations(ast, |relation| { + let relation = relation.to_string(); + let parts = relation.split(".").collect::<Vec<&str>>(); + let table_name = parts.last().unwrap(); + + if self.forbidden_tables.contains(&table_name.to_string()) { + found = Some(table_name.to_string()); + ControlFlow::<()>::Break(()) + } else { + ControlFlow::<()>::Continue(()) + } + }); + + if let Some(found) = found { + Ok(PluginOutput::Deny(format!( + "permission for table \"{}\" denied", + found + ))) + } else { + Ok(PluginOutput::Allow) + } + } +} From b7a761918c8530dfec2cd770942a534ff98c52fa Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 May 2023 15:15:30 -0700 Subject: [PATCH 06/11] comment --- src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 7a53eeb6..331a0da4 --- a/src/client.rs +++ b/src/client.rs @@ -766,7 +766,7 @@ where self.stats.register(self.stats.clone()); - // Error returned by one of the plugins. + // Result returned by one of the plugins. // Our custom protocol loop. 
From 5bedcfd59994c0713b05a73fe1bb2aa13f10043f Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 2 May 2023 15:22:19 -0700 Subject: [PATCH 07/11] fix tests --- src/query_router.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/query_router.rs b/src/query_router.rs index 820119a0..b5842b7f 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -1155,6 +1155,8 @@ mod test { auth_query: None, auth_query_password: None, auth_query_user: None, + db: "test".to_string(), + plugins: None, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None); @@ -1228,6 +1230,8 @@ mod test { auth_query: None, auth_query_password: None, auth_query_user: None, + db: "test".to_string(), + plugins: None, }; let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings.clone()); From 58718797a644acd18f4dab75a41bbac627e52792 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 May 2023 08:30:03 -0700 Subject: [PATCH 08/11] Tests --- src/plugins/intercept.rs | 15 +++++++++++++++ src/plugins/mod.rs | 2 +- src/query_router.rs | 23 +++++++++++++++++++---- tests/ruby/helpers/pgcat_helper.rb | 1 + tests/ruby/plugins_spec.rb | 14 ++++++++++++++ 5 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 tests/ruby/plugins_spec.rb diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs index 0adbb177..6e250dca 100644 --- a/src/plugins/intercept.rs +++ b/src/plugins/intercept.rs @@ -6,6 +6,7 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use sqlparser::ast::Statement; use std::collections::HashMap; @@ -36,6 +37,20 @@ pub fn configure(pools: &PoolMap) { CONFIG.store(Arc::new(config)); } +// TODO: use these structs for deserialization +#[derive(Serialize, Deserialize)] +pub struct Rule { + query: String, + schema: Vec<Column>, + result: Vec<Vec<String>>, +} + +#[derive(Serialize, Deserialize)] +pub struct Column { + name: String, + data_type: String, +} + /// The intercept plugin. 
pub struct Intercept; diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index 326859c2..92fa70b7 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -19,7 +19,7 @@ use sqlparser::ast::Statement; pub use intercept::Intercept; pub use table_access::TableAccess; -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq)] pub enum PluginOutput { Allow, Deny(String), diff --git a/src/query_router.rs b/src/query_router.rs index b5842b7f..93bcd4f2 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -1374,11 +1374,26 @@ mod test { assert!(qr.placeholders.is_empty()); } - #[test] - fn test_parse() { + #[tokio::test] + async fn test_table_access_plugin() { + QueryRouter::setup(); + + let mut qr = QueryRouter::new(); + + let mut pool_settings = PoolSettings::default(); + pool_settings.plugins = Some(vec![String::from("pg_table_access")]); + qr.update_pool_settings(pool_settings); + let query = simple_query("SELECT * FROM pg_database"); - let ast = QueryRouter::parse(&query); + let ast = QueryRouter::parse(&query).unwrap(); + + let res = qr.execute_plugins(&ast).await; - assert!(ast.is_ok()); + assert_eq!( + res, + Ok(PluginOutput::Deny( + "permission for table \"pg_database\" denied".to_string() + )) + ); } } diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index ad4c32a4..eb0cdaa9 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -27,6 +27,7 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod primary2 = PgInstance.new(8432, user["username"], user["password"], "shard2") pgcat_cfg = pgcat.current_config + pgcat_cfg["general"]["query_router_plugins"] = ["intercept"] pgcat_cfg["pools"] = { "#{pool_name}" => { "default_role" => "any", diff --git a/tests/ruby/plugins_spec.rb b/tests/ruby/plugins_spec.rb new file mode 100644 index 00000000..d4e233ab --- /dev/null +++ b/tests/ruby/plugins_spec.rb @@ -0,0 +1,14 @@ +require_relative 'spec_helper' + + +describe "Plugins" do + let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) } + + context "intercept" do + it "will intercept an intellij query" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + res = conn.exec("select current_database() as a, current_schemas(false) as b") + expect(res.values).to eq([["sharded_db", "{public}"]]) + end + end +end From f7dea63fb2c92a7d819c84500dd318ffdcbf4f7a Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 May 2023 08:32:05 -0700 Subject: [PATCH 09/11] unused errors --- src/errors.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/errors.rs b/src/errors.rs index 3dfc86fb..b1796eee 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -26,9 +26,6 @@ pub enum Error { AuthPassthroughError(String), UnsupportedStatement, QueryRouterParserError(String), - PermissionDenied(String), - PermissionDeniedTable(String), - QueryDenied(String), } #[derive(Clone, PartialEq, Debug)] From ce8979617c680c2397c4eb4a05aec4c917c771cc Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 May 2023 08:49:57 -0700 Subject: [PATCH 10/11] Increase reaper rate to actually enforce settings --- src/pool.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/pool.rs b/src/pool.rs index 158ff559..344a4e38 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -373,12 +373,19 @@ impl ConnectionPool { }, }; + println!("\n\n\n\n"); + println!("Idle timeout({}): {}", pool_name, idle_timeout); + let pool = Pool::builder() .max_size(user.pool_size) .min_idle(user.min_pool_size) 
.connection_timeout(std::time::Duration::from_millis(connect_timeout)) .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) + .reaper_rate(std::time::Duration::from_millis(std::cmp::min( + idle_timeout, + server_lifetime, + ))) .test_on_check_out(false) .build(manager) .await?; From 01dd479b5da541b5644ee7c825c618210f1a75ac Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 3 May 2023 08:55:32 -0700 Subject: [PATCH 11/11] ok --- src/pool.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index 344a4e38..b986548a 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -61,6 +61,8 @@ pub struct PoolIdentifier { pub user: String, } +static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default + impl PoolIdentifier { /// Create a new user/pool identifier. pub fn new(db: &str, user: &str) -> PoolIdentifier { @@ -373,8 +375,12 @@ impl ConnectionPool { }, }; - println!("\n\n\n\n"); - println!("Idle timeout({}): {}", pool_name, idle_timeout); + let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE] + .iter() + .min() + .unwrap(); + + debug!("Pool reaper rate: {}ms", reaper_rate); let pool = Pool::builder() .max_size(user.pool_size) @@ -382,10 +388,7 @@ impl ConnectionPool { .connection_timeout(std::time::Duration::from_millis(connect_timeout)) .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) - .reaper_rate(std::time::Duration::from_millis(std::cmp::min( - idle_timeout, - server_lifetime, - ))) + .reaper_rate(std::time::Duration::from_millis(reaper_rate)) .test_on_check_out(false) .build(manager) .await?;
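
Not part of the patches above: a minimal sketch of what a third plugin could look like against the Plugin trait introduced in PATCH 05. The BlockDdl plugin and its name are hypothetical — it is not wired into execute_plugins() or the query_router_plugins config list — but it illustrates the trait contract assumed by this series: inspect the parsed statements and return Allow or Deny.

use async_trait::async_trait;
use sqlparser::ast::Statement;

use crate::{
    errors::Error,
    plugins::{Plugin, PluginOutput},
    query_router::QueryRouter,
};

/// Hypothetical example plugin: refuse DDL through the pooler.
pub struct BlockDdl;

#[async_trait]
impl Plugin for BlockDdl {
    async fn run(
        &mut self,
        _query_router: &QueryRouter,
        ast: &Vec<Statement>,
    ) -> Result<PluginOutput, Error> {
        for statement in ast {
            // Deny the common DDL statements; everything else passes through.
            match statement {
                Statement::Drop { .. }
                | Statement::Truncate { .. }
                | Statement::AlterTable { .. } => {
                    return Ok(PluginOutput::Deny(
                        "DDL is not allowed through the pooler".into(),
                    ));
                }
                _ => (),
            }
        }

        Ok(PluginOutput::Allow)
    }
}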