diff --git a/Cargo.lock b/Cargo.lock index dad2e6c5..6724a92c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -882,7 +882,7 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pgcat" -version = "1.0.1" +version = "1.0.2-alpha1" dependencies = [ "arc-swap", "async-trait", @@ -913,6 +913,7 @@ dependencies = [ "rustls-pemfile", "serde", "serde_derive", + "serde_json", "sha-1", "sha2", "socket2", @@ -1174,6 +1175,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + [[package]] name = "scopeguard" version = "1.1.0" @@ -1201,6 +1208,9 @@ name = "serde" version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" +dependencies = [ + "serde_derive", +] [[package]] name = "serde_derive" @@ -1213,6 +1223,17 @@ dependencies = [ "syn 2.0.9", ] +[[package]] +name = "serde_json" +version = "1.0.96" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.1" @@ -1297,6 +1318,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a" dependencies = [ "log", + "sqlparser_derive", +] + +[[package]] +name = "sqlparser_derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 436c3dd5..80549821 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "1.0.1" +version = "1.0.2-alpha1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -14,12 +14,12 @@ rand = "0.8" chrono = "0.4" sha-1 = "0.10" toml = "0.7" -serde = "1" +serde = { version = "1", features = ["derive"] } serde_derive = "1" regex = "1" num_cpus = "1" once_cell = "1" -sqlparser = "0.33.0" +sqlparser = {version = "0.33", features = ["visitor"] } log = "0.4" arc-swap = "1" env_logger = "0.10" @@ -44,6 +44,7 @@ webpki-roots = "0.23" rustls = { version = "0.21", features = ["dangerous_configuration"] } trust-dns-resolver = "0.22.0" tokio-test = "0.4.2" +serde_json = "1" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/pgcat.toml b/pgcat.toml index c844ce1f..dfb57822 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -77,6 +77,9 @@ admin_username = "admin_user" # Password to access the virtual administrative database admin_password = "admin_pass" +# Plugins!! +# query_router_plugins = ["pg_table_access", "intercept"] + # pool configs are structured as pool. # the pool_name is what clients use as database name when connecting. # For a pool named `sharded_db`, clients access that pool using connection string like diff --git a/src/admin.rs b/src/admin.rs index 03af755c..ceba20c8 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -12,9 +12,9 @@ use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; use crate::errors::Error; use crate::messages::*; +use crate::pool::ClientServerMap; use crate::pool::{get_all_pools, get_pool}; use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState}; -use crate::ClientServerMap; pub fn generate_server_info_for_admin() -> BytesMut { let mut server_info = BytesMut::new(); diff --git a/src/client.rs b/src/client.rs index efde7554..331a0da4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::constants::*; use crate::messages::*; +use crate::plugins::PluginOutput; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; @@ -765,6 +766,9 @@ where self.stats.register(self.stats.clone()); + // Result returned by one of the plugins. + let mut plugin_output = None; + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. @@ -815,7 +819,25 @@ where 'Q' => { if query_router.query_parser_enabled() { - query_router.infer(&message); + if let Ok(ast) = QueryRouter::parse(&message) { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + continue; + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } + + _ => (), + }; + + let _ = query_router.infer(&ast); + } } } @@ -823,7 +845,13 @@ where self.buffer.put(&message[..]); if query_router.query_parser_enabled() { - query_router.infer(&message); + if let Ok(ast) = QueryRouter::parse(&message) { + if let Ok(output) = query_router.execute_plugins(&ast).await { + plugin_output = Some(output); + } + + let _ = query_router.infer(&ast); + } } continue; @@ -857,6 +885,18 @@ where continue; } + // Check on plugin results. + match plugin_output { + Some(PluginOutput::Deny(error)) => { + self.buffer.clear(); + error_response(&mut self.write, &error).await?; + plugin_output = None; + continue; + } + + _ => (), + }; + // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config // when starting a query. @@ -1085,6 +1125,27 @@ where match code { // Query 'Q' => { + if query_router.query_parser_enabled() { + if let Ok(ast) = QueryRouter::parse(&message) { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + continue; + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } + + _ => (), + }; + + let _ = query_router.infer(&ast); + } + } debug!("Sending query to server"); self.send_and_receive_loop( @@ -1124,6 +1185,14 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 'P' => { + if query_router.query_parser_enabled() { + if let Ok(ast) = QueryRouter::parse(&message) { + if let Ok(output) = query_router.execute_plugins(&ast).await { + plugin_output = Some(output); + } + } + } + self.buffer.put(&message[..]); } @@ -1155,6 +1224,24 @@ where 'S' => { debug!("Sending query to server"); + match plugin_output { + Some(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + plugin_output = None; + self.buffer.clear(); + continue; + } + + Some(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + plugin_output = None; + self.buffer.clear(); + continue; + } + + _ => (), + }; + self.buffer.put(&message[..]); let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; diff --git a/src/config.rs b/src/config.rs index fd7d3912..aa98421c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -298,9 +298,12 @@ pub struct General { pub admin_username: String, pub admin_password: String, + // Support for auth query pub auth_query: Option, pub auth_query_user: Option, pub auth_query_password: Option, + + pub query_router_plugins: Option>, } impl General { @@ -401,6 +404,7 @@ impl Default for General { auth_query_user: None, auth_query_password: None, server_lifetime: 1000 * 3600 * 24, // 24 hours, + query_router_plugins: None, } } } diff --git a/src/errors.rs b/src/errors.rs index fb70c042..b1796eee 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,7 +1,7 @@ //! Errors. /// Various errors. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Error { SocketError(String), ClientSocketError(String, ClientIdentifier), @@ -24,6 +24,8 @@ pub enum Error { ParseBytesError(String), AuthError(String), AuthPassthroughError(String), + UnsupportedStatement, + QueryRouterParserError(String), } #[derive(Clone, PartialEq, Debug)] diff --git a/src/lib.rs b/src/lib.rs index 3a58bb38..db6167db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,6 @@ +pub mod admin; pub mod auth_passthrough; +pub mod client; pub mod config; pub mod constants; pub mod dns_cache; @@ -6,7 +8,10 @@ pub mod errors; pub mod messages; pub mod mirrors; pub mod multi_logger; +pub mod plugins; pub mod pool; +pub mod prometheus; +pub mod query_router; pub mod scram; pub mod server; pub mod sharding; diff --git a/src/main.rs b/src/main.rs index dc48dd58..6af4db95 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,37 +61,19 @@ use std::str::FromStr; use std::sync::Arc; use tokio::sync::broadcast; -mod admin; -mod auth_passthrough; -mod client; -mod config; -mod constants; -mod dns_cache; -mod errors; -mod messages; -mod mirrors; -mod multi_logger; -mod pool; -mod prometheus; -mod query_router; -mod scram; -mod server; -mod sharding; -mod stats; -mod tls; - -use crate::config::{get_config, reload_config, VERSION}; -use crate::messages::configure_socket; -use crate::pool::{ClientServerMap, ConnectionPool}; -use crate::prometheus::start_metric_server; -use crate::stats::{Collector, Reporter, REPORTER}; +use pgcat::config::{get_config, reload_config, VERSION}; +use pgcat::dns_cache; +use pgcat::messages::configure_socket; +use pgcat::pool::{ClientServerMap, ConnectionPool}; +use pgcat::prometheus::start_metric_server; +use pgcat::stats::{Collector, Reporter, REPORTER}; fn main() -> Result<(), Box> { - multi_logger::MultiLogger::init().unwrap(); + pgcat::multi_logger::MultiLogger::init().unwrap(); info!("Welcome to PgCat! Meow. (Version {})", VERSION); - if !query_router::QueryRouter::setup() { + if !pgcat::query_router::QueryRouter::setup() { error!("Could not setup query router"); std::process::exit(exitcode::CONFIG); } @@ -109,7 +91,7 @@ fn main() -> Result<(), Box> { let runtime = Builder::new_multi_thread().worker_threads(1).build()?; runtime.block_on(async { - match config::parse(&config_file).await { + match pgcat::config::parse(&config_file).await { Ok(_) => (), Err(err) => { error!("Config parse error: {:?}", err); @@ -168,14 +150,14 @@ fn main() -> Result<(), Box> { // Statistics reporting. REPORTER.store(Arc::new(Reporter::default())); - // Starts (if enabled) dns cache before pools initialization - match dns_cache::CachedResolver::from_config().await { - Ok(_) => (), - Err(err) => error!("DNS cache initialization error: {:?}", err), - }; + // Starts (if enabled) dns cache before pools initialization + match dns_cache::CachedResolver::from_config().await { + Ok(_) => (), + Err(err) => error!("DNS cache initialization error: {:?}", err), + }; - // Connection pool that allows to query all shards and replicas. - match ConnectionPool::from_config(client_server_map.clone()).await { + // Connection pool that allows to query all shards and replicas. + match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), Err(err) => { error!("Pool error: {:?}", err); @@ -303,7 +285,7 @@ fn main() -> Result<(), Box> { tokio::task::spawn(async move { let start = chrono::offset::Utc::now().naive_utc(); - match client::client_entrypoint( + match pgcat::client::client_entrypoint( socket, client_server_map, shutdown_rx, @@ -334,7 +316,7 @@ fn main() -> Result<(), Box> { Err(err) => { match err { - errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), + pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), _ => warn!("Client disconnected with error {:?}", err), } diff --git a/src/messages.rs b/src/messages.rs index 0e980fe6..ee4886df 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -20,6 +20,10 @@ pub enum DataType { Text, Int4, Numeric, + Bool, + Oid, + AnyArray, + Any, } impl From<&DataType> for i32 { @@ -28,6 +32,10 @@ impl From<&DataType> for i32 { DataType::Text => 25, DataType::Int4 => 23, DataType::Numeric => 1700, + DataType::Bool => 16, + DataType::Oid => 26, + DataType::AnyArray => 2277, + DataType::Any => 2276, } } } @@ -443,6 +451,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { DataType::Text => -1, DataType::Int4 => 4, DataType::Numeric => -1, + DataType::Bool => 1, + DataType::Oid => 4, + DataType::AnyArray => -1, + DataType::Any => -1, }; row_desc.put_i16(type_size); @@ -481,6 +493,29 @@ pub fn data_row(row: &Vec) -> BytesMut { res } +pub fn data_row_nullable(row: &Vec>) -> BytesMut { + let mut res = BytesMut::new(); + let mut data_row = BytesMut::new(); + + data_row.put_i16(row.len() as i16); + + for column in row { + if let Some(column) = column { + let column = column.as_bytes(); + data_row.put_i32(column.len() as i32); + data_row.put_slice(column); + } else { + data_row.put_i32(-1 as i32); + } + } + + res.put_u8(b'D'); + res.put_i32(data_row.len() as i32 + 4); + res.put(data_row); + + res +} + /// Create a CommandComplete message. pub fn command_complete(command: &str) -> BytesMut { let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs new file mode 100644 index 00000000..6e250dca --- /dev/null +++ b/src/plugins/intercept.rs @@ -0,0 +1,278 @@ +//! The intercept plugin. +//! +//! It intercepts queries and returns fake results. + +use arc_swap::ArcSwap; +use async_trait::async_trait; +use bytes::{BufMut, BytesMut}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use sqlparser::ast::Statement; +use std::collections::HashMap; + +use log::debug; +use std::sync::Arc; + +use crate::{ + errors::Error, + messages::{command_complete, data_row_nullable, row_description, DataType}, + plugins::{Plugin, PluginOutput}, + pool::{PoolIdentifier, PoolMap}, + query_router::QueryRouter, +}; + +pub static CONFIG: Lazy>> = + Lazy::new(|| ArcSwap::from_pointee(HashMap::new())); + +/// Configure the intercept plugin. +pub fn configure(pools: &PoolMap) { + let mut config = HashMap::new(); + for (identifier, _) in pools.iter() { + // TODO: make this configurable from a text config. + let value = fool_datagrip(&identifier.db, &identifier.user); + config.insert(identifier.clone(), value); + } + + CONFIG.store(Arc::new(config)); +} + +// TODO: use these structs for deserialization +#[derive(Serialize, Deserialize)] +pub struct Rule { + query: String, + schema: Vec, + result: Vec>, +} + +#[derive(Serialize, Deserialize)] +pub struct Column { + name: String, + data_type: String, +} + +/// The intercept plugin. +pub struct Intercept; + +#[async_trait] +impl Plugin for Intercept { + async fn run( + &mut self, + query_router: &QueryRouter, + ast: &Vec, + ) -> Result { + if ast.is_empty() { + return Ok(PluginOutput::Allow); + } + + let mut result = BytesMut::new(); + let query_map = match CONFIG.load().get(&PoolIdentifier::new( + &query_router.pool_settings().db, + &query_router.pool_settings().user.username, + )) { + Some(query_map) => query_map.clone(), + None => return Ok(PluginOutput::Allow), + }; + + for q in ast { + // Normalization + let q = q.to_string().to_ascii_lowercase(); + + for target in query_map.as_array().unwrap().iter() { + if target["query"].as_str().unwrap() == q { + debug!("Query matched: {}", q); + + let rd = target["schema"] + .as_array() + .unwrap() + .iter() + .map(|row| { + let row = row.as_object().unwrap(); + ( + row["name"].as_str().unwrap(), + match row["data_type"].as_str().unwrap() { + "text" => DataType::Text, + "anyarray" => DataType::AnyArray, + "oid" => DataType::Oid, + "bool" => DataType::Bool, + "int4" => DataType::Int4, + _ => DataType::Any, + }, + ) + }) + .collect::>(); + + result.put(row_description(&rd)); + + target["result"].as_array().unwrap().iter().for_each(|row| { + let row = row + .as_array() + .unwrap() + .iter() + .map(|s| { + let s = s.as_str().unwrap().to_string(); + + if s == "" { + None + } else { + Some(s) + } + }) + .collect::>>(); + result.put(data_row_nullable(&row)); + }); + + result.put(command_complete("SELECT")); + } + } + } + + if !result.is_empty() { + result.put_u8(b'Z'); + result.put_i32(5); + result.put_u8(b'I'); + + return Ok(PluginOutput::Intercept(result)); + } else { + Ok(PluginOutput::Allow) + } + } +} + +/// Make IntelliJ SQL plugin believe it's talking to an actual database +/// instead of PgCat. +fn fool_datagrip(database: &str, user: &str) -> Value { + json!([ + { + "query": "select current_database() as a, current_schemas(false) as b", + "schema": [ + { + "name": "a", + "data_type": "text", + }, + { + "name": "b", + "data_type": "anyarray", + }, + ], + + "result": [ + [database, "{public}"], + ], + }, + { + "query": "select current_database(), current_schema(), current_user", + "schema": [ + { + "name": "current_database", + "data_type": "text", + }, + { + "name": "current_schema", + "data_type": "text", + }, + { + "name": "current_user", + "data_type": "text", + } + ], + + "result": [ + ["sharded_db", "public", "sharding_user"], + ], + }, + { + "query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end", + "schema": [ + { + "name": "id", + "data_type": "oid", + }, + { + "name": "name", + "data_type": "text", + }, + { + "name": "description", + "data_type": "text", + }, + { + "name": "is_template", + "data_type": "bool", + }, + { + "name": "allow_connections", + "data_type": "bool", + }, + { + "name": "owner", + "data_type": "text", + } + ], + "result": [ + ["16387", database, "", "f", "t", user], + ] + }, + { + "query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid", + "schema": [ + { + "name": "role_id", + "data_type": "oid", + }, + { + "name": "role_name", + "data_type": "text", + }, + { + "name": "is_super", + "data_type": "bool", + }, + { + "name": "is_inherit", + "data_type": "bool", + }, + { + "name": "can_createrole", + "data_type": "bool", + }, + { + "name": "can_createdb", + "data_type": "bool", + }, + { + "name": "can_login", + "data_type": "bool", + }, + { + "name": "is_replication", + "data_type": "bool", + }, + { + "name": "conn_limit", + "data_type": "int4", + }, + { + "name": "valid_until", + "data_type": "text", + }, + { + "name": "bypass_rls", + "data_type": "bool", + }, + { + "name": "config", + "data_type": "text", + }, + { + "name": "description", + "data_type": "text", + }, + ], + "result": [ + ["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""], + ["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""], + ] + } + ]) +} diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs new file mode 100644 index 00000000..92fa70b7 --- /dev/null +++ b/src/plugins/mod.rs @@ -0,0 +1,40 @@ +//! The plugin ecosystem. +//! +//! Currently plugins only grant access or deny access to the database for a particual query. +//! Example use cases: +//! - block known bad queries +//! - block access to system catalogs +//! - block dangerous modifications like `DROP TABLE` +//! - etc +//! + +pub mod intercept; +pub mod table_access; + +use crate::{errors::Error, query_router::QueryRouter}; +use async_trait::async_trait; +use bytes::BytesMut; +use sqlparser::ast::Statement; + +pub use intercept::Intercept; +pub use table_access::TableAccess; + +#[derive(Clone, Debug, PartialEq)] +pub enum PluginOutput { + Allow, + Deny(String), + Overwrite(Vec), + Intercept(BytesMut), +} + +#[async_trait] +pub trait Plugin { + // Custom output is allowed because we want to extend this system + // to rewriting queries some day. So an output of a plugin could be + // a rewritten AST. + async fn run( + &mut self, + query_router: &QueryRouter, + ast: &Vec, + ) -> Result; +} diff --git a/src/plugins/table_access.rs b/src/plugins/table_access.rs new file mode 100644 index 00000000..2e23278a --- /dev/null +++ b/src/plugins/table_access.rs @@ -0,0 +1,50 @@ +//! This query router plugin will check if the user can access a particular +//! table as part of their query. If they can't, the query will not be routed. + +use async_trait::async_trait; +use sqlparser::ast::{visit_relations, Statement}; + +use crate::{ + errors::Error, + plugins::{Plugin, PluginOutput}, + query_router::QueryRouter, +}; + +use core::ops::ControlFlow; + +pub struct TableAccess { + pub forbidden_tables: Vec, +} + +#[async_trait] +impl Plugin for TableAccess { + async fn run( + &mut self, + _query_router: &QueryRouter, + ast: &Vec, + ) -> Result { + let mut found = None; + + visit_relations(ast, |relation| { + let relation = relation.to_string(); + let parts = relation.split(".").collect::>(); + let table_name = parts.last().unwrap(); + + if self.forbidden_tables.contains(&table_name.to_string()) { + found = Some(table_name.to_string()); + ControlFlow::<()>::Break(()) + } else { + ControlFlow::<()>::Continue(()) + } + }); + + if let Some(found) = found { + Ok(PluginOutput::Deny(format!( + "permission for table \"{}\" denied", + found + ))) + } else { + Ok(PluginOutput::Allow) + } + } +} diff --git a/src/pool.rs b/src/pool.rs index ee8de446..b986548a 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -61,6 +61,8 @@ pub struct PoolIdentifier { pub user: String, } +static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default + impl PoolIdentifier { /// Create a new user/pool identifier. pub fn new(db: &str, user: &str) -> PoolIdentifier { @@ -91,6 +93,7 @@ pub struct PoolSettings { // Connecting user. pub user: User, + pub db: String, // Default server role to connect to. pub default_role: Option, @@ -129,6 +132,8 @@ pub struct PoolSettings { pub auth_query: Option, pub auth_query_user: Option, pub auth_query_password: Option, + + pub plugins: Option>, } impl Default for PoolSettings { @@ -138,6 +143,7 @@ impl Default for PoolSettings { load_balancing_mode: LoadBalancingMode::Random, shards: 1, user: User::default(), + db: String::default(), default_role: None, query_parser_enabled: false, primary_reads_enabled: true, @@ -152,6 +158,7 @@ impl Default for PoolSettings { auth_query: None, auth_query_user: None, auth_query_password: None, + plugins: None, } } } @@ -368,12 +375,20 @@ impl ConnectionPool { }, }; + let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE] + .iter() + .min() + .unwrap(); + + debug!("Pool reaper rate: {}ms", reaper_rate); + let pool = Pool::builder() .max_size(user.pool_size) .min_idle(user.min_pool_size) .connection_timeout(std::time::Duration::from_millis(connect_timeout)) .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) + .reaper_rate(std::time::Duration::from_millis(reaper_rate)) .test_on_check_out(false) .build(manager) .await?; @@ -412,6 +427,7 @@ impl ConnectionPool { // shards: pool_config.shards.clone(), shards: shard_ids.len(), user: user.clone(), + db: pool_name.clone(), default_role: match pool_config.default_role.as_str() { "any" => None, "replica" => Some(Role::Replica), @@ -437,6 +453,7 @@ impl ConnectionPool { auth_query: pool_config.auth_query.clone(), auth_query_user: pool_config.auth_query_user.clone(), auth_query_password: pool_config.auth_query_password.clone(), + plugins: config.general.query_router_plugins.clone(), }, validated: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)), @@ -456,6 +473,13 @@ impl ConnectionPool { } } + // Initialize plugins here if required. + if let Some(plugins) = config.general.query_router_plugins { + if plugins.contains(&String::from("intercept")) { + crate::plugins::intercept::configure(&new_pools); + } + } + POOLS.store(Arc::new(new_pools.clone())); Ok(()) } diff --git a/src/query_router.rs b/src/query_router.rs index 5b2ba0c4..93bcd4f2 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -6,13 +6,16 @@ use once_cell::sync::OnceCell; use regex::{Regex, RegexSet}; use sqlparser::ast::Statement::{Query, StartTransaction}; use sqlparser::ast::{ - BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value, + BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor, + Value, }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; +use crate::errors::Error; use crate::messages::BytesMutReader; +use crate::plugins::{Intercept, Plugin, PluginOutput, TableAccess}; use crate::pool::PoolSettings; use crate::sharding::Sharder; @@ -129,6 +132,10 @@ impl QueryRouter { self.pool_settings = pool_settings; } + pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings { + &self.pool_settings + } + /// Try to parse a command and execute it. pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { let mut message_cursor = Cursor::new(message_buffer); @@ -324,10 +331,7 @@ impl QueryRouter { Some((command, value)) } - /// Try to infer which server to connect to based on the contents of the query. - pub fn infer(&mut self, message: &BytesMut) -> bool { - debug!("Inferring role"); - + pub fn parse(message: &BytesMut) -> Result, Error> { let mut message_cursor = Cursor::new(message); let code = message_cursor.get_u8() as char; @@ -353,28 +357,29 @@ impl QueryRouter { query } - _ => return false, + _ => return Err(Error::UnsupportedStatement), }; - let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) { - Ok(ast) => ast, + match Parser::parse_sql(&PostgreSqlDialect {}, &query) { + Ok(ast) => Ok(ast), Err(err) => { - // SELECT ... FOR UPDATE won't get parsed correctly. debug!("{}: {}", err, query); - self.active_role = Some(Role::Primary); - return false; + Err(Error::QueryRouterParserError(err.to_string())) } - }; + } + } - debug!("AST: {:?}", ast); + /// Try to infer which server to connect to based on the contents of the query. + pub fn infer(&mut self, ast: &Vec) -> Result<(), Error> { + debug!("Inferring role"); if ast.is_empty() { // That's weird, no idea, let's go to primary self.active_role = Some(Role::Primary); - return false; + return Err(Error::QueryRouterParserError("empty query".into())); } - for q in &ast { + for q in ast { match q { // All transactions go to the primary, probably a write. StartTransaction { .. } => { @@ -418,7 +423,7 @@ impl QueryRouter { }; } - true + Ok(()) } /// Parse the shard number from the Bind message @@ -783,6 +788,32 @@ impl QueryRouter { } } + /// Add your plugins here and execute them. + pub async fn execute_plugins(&self, ast: &Vec) -> Result { + if let Some(plugins) = &self.pool_settings.plugins { + if plugins.contains(&String::from("intercept")) { + let mut intercept = Intercept {}; + let result = intercept.run(&self, ast).await; + + if let Ok(PluginOutput::Intercept(output)) = result { + return Ok(PluginOutput::Intercept(output)); + } + } + + if plugins.contains(&String::from("pg_table_access")) { + let mut table_access = TableAccess { + forbidden_tables: vec![String::from("pg_database"), String::from("pg_roles")], + }; + + if let Ok(PluginOutput::Deny(error)) = table_access.run(&self, ast).await { + return Ok(PluginOutput::Deny(error)); + } + } + } + + Ok(PluginOutput::Allow) + } + fn set_sharding_key(&mut self, sharding_key: i64) -> Option { let sharder = Sharder::new( self.pool_settings.shards, @@ -810,11 +841,22 @@ impl QueryRouter { /// Should we attempt to parse queries? pub fn query_parser_enabled(&self) -> bool { let enabled = match self.query_parser_enabled { - None => self.pool_settings.query_parser_enabled, - Some(value) => value, - }; + None => { + debug!( + "Using pool settings, query_parser_enabled: {}", + self.pool_settings.query_parser_enabled + ); + self.pool_settings.query_parser_enabled + } - debug!("Query parser enabled: {}", enabled); + Some(value) => { + debug!( + "Using query parser override, query_parser_enabled: {}", + value + ); + value + } + }; enabled } @@ -862,7 +904,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -881,7 +923,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -893,7 +935,7 @@ mod test { let query = simple_query("SELECT * FROM items WHERE id = 5"); assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), None); } @@ -913,7 +955,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(&res)); + assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -1077,11 +1119,11 @@ mod test { assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(&query)); + assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); @@ -1113,6 +1155,8 @@ mod test { auth_query: None, auth_query_password: None, auth_query_user: None, + db: "test".to_string(), + plugins: None, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None); @@ -1142,15 +1186,24 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap()) + .is_ok()); assert_eq!(qr.role(), Role::Primary); - assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap()) + .is_ok()); assert_eq!(qr.role(), Role::Replica); - assert!(qr.infer(&simple_query( - "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.role(), Role::Primary); } @@ -1177,6 +1230,8 @@ mod test { auth_query: None, auth_query_password: None, auth_query_user: None, + db: "test".to_string(), + plugins: None, }; let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings.clone()); @@ -1208,47 +1263,84 @@ mod test { qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.shards = 3; - assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5"))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap()) + .is_ok()); assert_eq!(qr.shard(), 2); - assert!(qr.infer(&simple_query( - "SELECT one, two, three FROM public.data WHERE id = 6" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT one, two, three FROM public.data WHERE id = 6" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); - assert!(qr.infer(&simple_query( - "SELECT * FROM data + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT * FROM data INNER JOIN t2 ON data.id = 5 AND t2.data_id = data.id WHERE data.id = 5" - ))); + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 2); // Shard did not move because we couldn't determine the sharding key since it could be ambiguous // in the query. - assert!(qr.infer(&simple_query( - "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 2); - assert!(qr.infer(&simple_query( - r#"SELECT * FROM "public"."data" WHERE "id" = 6"# - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + r#"SELECT * FROM "public"."data" WHERE "id" = 6"# + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); - assert!(qr.infer(&simple_query( - r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 2); // Super unique sharding key qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string()); - assert!(qr.infer(&simple_query( - "SELECT * FROM table_x WHERE unique_enough_column_name = 6" - ))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query( + "SELECT * FROM table_x WHERE unique_enough_column_name = 6" + )) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); - assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))); + assert!(qr + .infer( + &QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5")) + .unwrap() + ) + .is_ok()); assert_eq!(qr.shard(), 0); } @@ -1272,11 +1364,36 @@ mod test { qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.shards = 3; - assert!(qr.infer(&simple_query(stmt))); + assert!(qr + .infer(&QueryRouter::parse(&simple_query(stmt)).unwrap()) + .is_ok()); assert_eq!(qr.placeholders.len(), 1); assert!(qr.infer_shard_from_bind(&bind)); assert_eq!(qr.shard(), 2); assert!(qr.placeholders.is_empty()); } + + #[tokio::test] + async fn test_table_access_plugin() { + QueryRouter::setup(); + + let mut qr = QueryRouter::new(); + + let mut pool_settings = PoolSettings::default(); + pool_settings.plugins = Some(vec![String::from("pg_table_access")]); + qr.update_pool_settings(pool_settings); + + let query = simple_query("SELECT * FROM pg_database"); + let ast = QueryRouter::parse(&query).unwrap(); + + let res = qr.execute_plugins(&ast).await; + + assert_eq!( + res, + Ok(PluginOutput::Deny( + "permission for table \"pg_database\" denied".to_string() + )) + ); + } } diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index ad4c32a4..eb0cdaa9 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -27,6 +27,7 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod primary2 = PgInstance.new(8432, user["username"], user["password"], "shard2") pgcat_cfg = pgcat.current_config + pgcat_cfg["general"]["query_router_plugins"] = ["intercept"] pgcat_cfg["pools"] = { "#{pool_name}" => { "default_role" => "any", diff --git a/tests/ruby/plugins_spec.rb b/tests/ruby/plugins_spec.rb new file mode 100644 index 00000000..d4e233ab --- /dev/null +++ b/tests/ruby/plugins_spec.rb @@ -0,0 +1,14 @@ +require_relative 'spec_helper' + + +describe "Plugins" do + let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) } + + context "intercept" do + it "will intercept an intellij query" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + res = conn.exec("select current_database() as a, current_schemas(false) as b") + expect(res.values).to eq([["sharded_db", "{public}"]]) + end + end +end