diff --git a/.circleci/config.yml b/.circleci/config.yml index 5e2d114a..c7f5c9fa 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -46,6 +46,14 @@ jobs: POSTGRES_PASSWORD: postgres POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256 + - image: postgres:14 + command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements"] + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 + # Add steps to the job # See: https://circleci.com/docs/2.0/configuration-reference/#steps steps: diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index a5cfab0b..4ba497c3 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -19,6 +19,7 @@ PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/q PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 7432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 8432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 9432 -U postgres -f tests/sharding/query_routing_setup.sql +PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 10432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 00000000..17f33211 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,2 @@ +edition = "2021" +hard_tabs = false diff --git a/CONFIG.md b/CONFIG.md index bcd6f09f..3cec2530 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -175,11 +175,41 @@ Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DAT ### admin_password ``` path: general.admin_password -default: "admin_pass" +default: ``` Password to access the virtual administrative database +### auth_query (experimental) +``` +path: general.auth_query +default: +``` + +Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be +established using the database configured in the pool. This parameter is inherited by every pool +and can be redefined in pool configuration. + +### auth_query_user (experimental) +``` +path: general.auth_query_user +default: +``` + +User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query +specified in `auth_query_user`. The connection will be established using the database configured in the pool. +This parameter is inherited by every pool and can be redefined in pool configuration. + +### auth_query_password (experimental) +``` +path: general.auth_query_password +default: +``` + +Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query +specified in `auth_query_user`. The connection will be established using the database configured in the pool. +This parameter is inherited by every pool and can be redefined in pool configuration. + ## `pools.` Section ### pool_mode @@ -281,6 +311,30 @@ default: 3000 Connect timeout can be overwritten in the pool +### auth_query (experimental) +``` +path: general.auth_query +default: +``` + +Auth query can be overwritten in the pool + +### auth_query_user (experimental) +``` +path: general.auth_query_user +default: +``` + +Auth query user can be overwritten in the pool + +### auth_query_password (experimental) +``` +path: general.auth_query_password +default: +``` + +Auth query password can be overwritten in the pool + ## `pools..users.` Section ### username diff --git a/Cargo.lock b/Cargo.lock index 7e4aa683..7b3d3778 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,7 +45,7 @@ checksum = "6227a8d6fdb862bcb100c4314d0d9579e5cd73fa6df31a2e6f6e1acd3c5f1207" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -54,6 +54,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.0" @@ -94,6 +100,12 @@ version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "bytes" version = "1.4.0" @@ -257,6 +269,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193" +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fnv" version = "1.0.7" @@ -732,12 +750,13 @@ dependencies = [ "arc-swap", "async-trait", "atomic_enum", - "base64", + "base64 0.21.0", "bb8", "bytes", "chrono", "env_logger", "exitcode", + "fallible-iterator", "futures", "hmac", "hyper", @@ -749,6 +768,7 @@ dependencies = [ "once_cell", "parking_lot", "phf", + "postgres-protocol", "rand", "regex", "rustls-pemfile", @@ -818,6 +838,24 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "postgres-protocol" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c" +dependencies = [ + "base64 0.13.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -945,7 +983,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64", + "base64 0.21.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3f0cbec6..89cfe643 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ version = "1.0.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] tokio = { version = "1", features = ["full"] } bytes = "1" @@ -38,6 +37,8 @@ futures = "0.3" socket2 = { version = "0.4.7", features = ["all"] } nix = "0.26.2" atomic_enum = "0.2.0" +postgres-protocol = "0.6.4" +fallible-iterator = "0.2" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/README.md b/README.md index 4d6f599d..63b5ab1a 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal | Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. | | Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. | | Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. | +| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. | ## Status diff --git a/dev/docker-compose.yaml b/dev/docker-compose.yaml index 15621e87..71704bcb 100644 --- a/dev/docker-compose.yaml +++ b/dev/docker-compose.yaml @@ -58,6 +58,13 @@ services: POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256 PGPORT: 9432 command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] + pg5: + <<: *common-definition-pg + environment: + <<: *common-env-pg + POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 + PGPORT: 10432 + command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] toxiproxy: build: . @@ -71,6 +78,7 @@ services: - pg2 - pg3 - pg4 + - pg5 pgcat-shell: stdin_open: true diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs new file mode 100644 index 00000000..b9f0e97f --- /dev/null +++ b/src/auth_passthrough.rs @@ -0,0 +1,107 @@ +use crate::errors::Error; +use crate::server::Server; +use log::debug; + +#[derive(Clone, Debug)] +pub struct AuthPassthrough { + password: String, + query: String, + user: String, +} + +impl AuthPassthrough { + /// Initializes an AuthPassthrough. + pub fn new(query: &str, user: &str, password: &str) -> Self { + AuthPassthrough { + password: password.to_string(), + query: query.to_string(), + user: user.to_string(), + } + } + + /// Returns an AuthPassthrough given the pool configuration. + /// If any of required values is not set, None is returned. + pub fn from_pool_config(pool_config: &crate::config::Pool) -> Option { + if pool_config.is_auth_query_configured() { + return Some(AuthPassthrough::new( + pool_config.auth_query.as_ref().unwrap(), + pool_config.auth_query_user.as_ref().unwrap(), + pool_config.auth_query_password.as_ref().unwrap(), + )); + } + + None + } + + /// Returns an AuthPassthrough given the pool settings. + /// If any of required values is not set, None is returned. + pub fn from_pool_settings(pool_settings: &crate::pool::PoolSettings) -> Option { + let pool_config = crate::config::Pool { + auth_query: pool_settings.auth_query.clone(), + auth_query_password: pool_settings.auth_query_password.clone(), + auth_query_user: pool_settings.auth_query_user.clone(), + ..Default::default() + }; + + AuthPassthrough::from_pool_config(&pool_config) + } + + /// Connects to server and executes auth_query for the specified address. + /// If the response is a row with two columns containing the username set in the address. + /// and its MD5 hash, the MD5 hash returned. + /// + /// Note that the query is executed, changing $1 with the name of the user + /// this is so we only hold in memory (and transfer) the least amount of 'sensitive' data. + /// Also, it is compatible with pgbouncer. + /// + /// # Arguments + /// + /// * `address` - An Address of the server we want to connect to. The username for the hash will be obtained from this value. + /// + /// # Examples + /// + /// ``` + /// use pgcat::auth_passthrough::AuthPassthrough; + /// use pgcat::config::Address; + /// let auth_passthrough = AuthPassthrough::new("SELECT * FROM public.user_lookup('$1');", "postgres", "postgres"); + /// auth_passthrough.fetch_hash(&Address::default()); + /// ``` + /// + pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result { + let auth_user = crate::config::User { + username: self.user.clone(), + password: Some(self.password.clone()), + pool_size: 1, + statement_timeout: 0, + }; + + let user = &address.username; + + debug!("Connecting to server to obtain auth hashes."); + let auth_query = self.query.replace("$1", user); + match Server::exec_simple_query(address, &auth_user, &auth_query).await { + Ok(password_data) => { + if password_data.len() == 2 && password_data.first().unwrap() == user { + if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") { + Ok(stripped_hash.to_string()) + } + else { + Err(Error::AuthPassthroughError( + "Obtained hash from auth_query does not seem to be in md5 format.".to_string(), + )) + } + } else { + Err(Error::AuthPassthroughError( + "Data obtained from query does not follow the scheme 'user','hash'." + .to_string(), + )) + } + } + Err(err) => { + Err(Error::AuthPassthroughError( + format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}", + user, err))) + } + } + } +} diff --git a/src/client.rs b/src/client.rs index f9f4e015..d75c069d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,9 +12,9 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; +use crate::auth_passthrough::AuthPassthrough; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::constants::*; - use crate::messages::*; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; @@ -377,6 +377,20 @@ pub async fn startup_tls( } } +async fn refetch_auth_hash(pool: &ConnectionPool) -> Result { + let address = pool.address(0, 0); + if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { + let hash = apt.fetch_hash(address).await?; + + return Ok(hash); + } + + Err(Error::ClientError(format!( + "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", + address.username, address.database + ))) +} + impl Client where S: tokio::io::AsyncRead + std::marker::Unpin, @@ -509,14 +523,68 @@ where } }; - // Compare server and client hashes. - let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt); + // Obtain the hash to compare, we give preference to that written in cleartext in config + // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained + // when the pool was created. If there is no hash there, we try to fetch it one more time. + let password_hash = if let Some(password) = &pool.settings.user.password { + Some(md5_hash_password(username, password, &salt)) + } else { + if !get_config().is_auth_query_configured() { + return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username))); + } - if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); - wrong_password(&mut write, username).await?; + let mut hash = (*pool.auth_hash.read()).clone(); - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + if hash.is_none() { + warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name); + match refetch_auth_hash(&pool).await { + Ok(fetched_hash) => { + warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name); + { + let mut pool_auth_hash = pool.auth_hash.write(); + *pool_auth_hash = Some(fetched_hash.clone()); + } + + hash = Some(fetched_hash); + } + Err(err) => { + return Err( + Error::ClientError( + format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}", + username, + pool_name, + application_name, + err) + ) + ); + } + } + }; + + Some(md5_hash_second_pass(&hash.unwrap(), &salt)) + }; + + // Once we have the resulting hash, we compare with what the client gave us. + // If they do not match and auth query is set up, we try to refetch the hash one more time + // to see if the password has changed since the pool was created. + // + // @TODO: we could end up fetching again the same password twice (see above). + if password_hash.unwrap() != password_response { + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name); + let fetched_hash = refetch_auth_hash(&pool).await?; + let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); + + // Ok password changed in server an auth is possible. + if new_password_hash == password_response { + warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name); + { + let mut pool_auth_hash = pool.auth_hash.write(); + *pool_auth_hash = Some(fetched_hash); + } + } else { + wrong_password(&mut write, username).await?; + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + } } let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; diff --git a/src/config.rs b/src/config.rs index 644532ac..6545457c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -177,7 +177,7 @@ impl Address { #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)] pub struct User { pub username: String, - pub password: String, + pub password: Option, pub pool_size: u32, #[serde(default)] // 0 pub statement_timeout: u64, @@ -187,7 +187,7 @@ impl Default for User { fn default() -> User { User { username: String::from("postgres"), - password: String::new(), + password: None, pool_size: 15, statement_timeout: 0, } @@ -250,6 +250,10 @@ pub struct General { pub tls_private_key: Option, pub admin_username: String, pub admin_password: String, + + pub auth_query: Option, + pub auth_query_user: Option, + pub auth_query_password: Option, } impl General { @@ -334,6 +338,9 @@ impl Default for General { tls_private_key: None, admin_username: String::from("admin"), admin_password: String::from("admin"), + auth_query: None, + auth_query_user: None, + auth_query_password: None, } } } @@ -406,6 +413,10 @@ pub struct Pool { pub shard_id_regex: Option, pub regex_search_limit: Option, + pub auth_query: Option, + pub auth_query_user: Option, + pub auth_query_password: Option, + pub shards: BTreeMap, pub users: BTreeMap, // Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it @@ -420,6 +431,12 @@ impl Pool { s.finish() } + pub fn is_auth_query_configured(&self) -> bool { + self.auth_query_password.is_some() + && self.auth_query_user.is_some() + && self.auth_query_password.is_some() + } + pub fn default_pool_mode() -> PoolMode { PoolMode::Transaction } @@ -512,6 +529,9 @@ impl Default for Pool { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: Some(1000), + auth_query: None, + auth_query_user: None, + auth_query_password: None, } } } @@ -612,9 +632,31 @@ pub struct Config { } impl Config { + pub fn is_auth_query_configured(&self) -> bool { + self.pools + .iter() + .any(|(_name, pool)| pool.is_auth_query_configured()) + } + pub fn default_path() -> String { String::from("pgcat.toml") } + + pub fn fill_up_auth_query_config(&mut self) { + for (_name, pool) in self.pools.iter_mut() { + if pool.auth_query.is_none() { + pool.auth_query = self.general.auth_query.clone(); + } + + if pool.auth_query_user.is_none() { + pool.auth_query_user = self.general.auth_query_user.clone(); + } + + if pool.auth_query_password.is_none() { + pool.auth_query_password = self.general.auth_query_password.clone(); + } + } + } } impl Default for Config { @@ -832,6 +874,35 @@ impl Config { } pub fn validate(&mut self) -> Result<(), Error> { + // Validation for auth_query feature + if self.general.auth_query.is_some() + && (self.general.auth_query_user.is_none() + || self.general.auth_query_password.is_none()) + { + error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`"); + return Err(Error::BadConfig); + } + + for (name, pool) in self.pools.iter() { + if pool.auth_query.is_some() + && (pool.auth_query_user.is_none() || pool.auth_query_password.is_none()) + { + error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name); + return Err(Error::BadConfig); + } + + for (_name, user_data) in pool.users.iter() { + if (pool.auth_query.is_none() + || pool.auth_query_password.is_none() + || pool.auth_query_user.is_none()) + && user_data.password.is_none() + { + error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name); + return Err(Error::BadConfig); + } + } + } + // Validate TLS! match self.general.tls_certificate.clone() { Some(tls_certificate) => { @@ -911,6 +982,7 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; + config.fill_up_auth_query_config(); config.validate()?; config.path = path.to_string(); @@ -980,7 +1052,10 @@ mod test { "sharding_user" ); assert_eq!( - get_config().pools["sharded_db"].users["1"].password, + get_config().pools["sharded_db"].users["1"] + .password + .as_ref() + .unwrap(), "other_user" ); assert_eq!(get_config().pools["sharded_db"].users["1"].pool_size, 21); @@ -1005,10 +1080,16 @@ mod test { "simple_user" ); assert_eq!( - get_config().pools["simple_db"].users["0"].password, + get_config().pools["simple_db"].users["0"] + .password + .as_ref() + .unwrap(), "simple_user" ); assert_eq!(get_config().pools["simple_db"].users["0"].pool_size, 5); + assert_eq!(get_config().general.auth_query, None); + assert_eq!(get_config().general.auth_query_user, None); + assert_eq!(get_config().general.auth_query_password, None); } #[tokio::test] diff --git a/src/errors.rs b/src/errors.rs index 310243c0..58fc088b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -15,4 +15,6 @@ pub enum Error { StatementTimeout, ShuttingDown, ParseBytesError(String), + AuthError(String), + AuthPassthroughError(String), } diff --git a/src/lib.rs b/src/lib.rs index 67aa9cba..2645cd42 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod auth_passthrough; pub mod config; pub mod constants; pub mod errors; diff --git a/src/main.rs b/src/main.rs index a59da210..4c8987f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,6 +61,7 @@ use std::sync::Arc; use tokio::sync::broadcast; mod admin; +mod auth_passthrough; mod client; mod config; mod constants; diff --git a/src/messages.rs b/src/messages.rs index c9ace4e0..61c36c6d 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -213,7 +213,13 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { let output = md5.finalize_reset(); // Second pass - md5.update(format!("{:x}", output)); + md5_hash_second_pass(&(format!("{:x}", output)), salt) +} + +pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec { + let mut md5 = Md5::new(); + // Second pass + md5.update(hash); md5.update(salt); let mut password = format!("md5{:x}", md5.finalize()) @@ -247,6 +253,20 @@ where write_all(stream, message).await } +pub async fn md5_password_with_hash(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let password = md5_hash_second_pass(hash, salt); + let mut message = BytesMut::with_capacity(password.len() as usize + 5); + + message.put_u8(b'p'); + message.put_i32(password.len() as i32 + 4); + message.put_slice(&password[..]); + + write_all(stream, message).await +} + /// Implements a response to our custom `SET SHARDING KEY` /// and `SET SERVER ROLE` commands. /// This tells the client we're ready for the next query. diff --git a/src/mirrors.rs b/src/mirrors.rs index 128fe220..17f91d4d 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -4,6 +4,7 @@ use std::sync::Arc; /// Packets arrive to us through a channel from the main client and we send them to the server. use bb8::Pool; use bytes::{Bytes, BytesMut}; +use parking_lot::RwLock; use crate::config::{get_config, Address, Role, User}; use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool}; @@ -41,6 +42,7 @@ impl MirroredClient { self.database.as_str(), ClientServerMap::default(), Arc::new(PoolStats::new(identifier, cfg.clone())), + Arc::new(RwLock::new(None)), ); Pool::builder() diff --git a/src/pool.rs b/src/pool.rs index f6f9118b..e1ab7cb4 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -20,6 +20,7 @@ use tokio::sync::Notify; use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User}; use crate::errors::Error; +use crate::auth_passthrough::AuthPassthrough; use crate::server::Server; use crate::sharding::ShardingFunction; use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats}; @@ -123,6 +124,11 @@ pub struct PoolSettings { // Limit how much of each query is searched for a potential shard regex match pub regex_search_limit: usize, + + // Auth query parameters + pub auth_query: Option, + pub auth_query_user: Option, + pub auth_query_password: Option, } impl Default for PoolSettings { @@ -143,6 +149,9 @@ impl Default for PoolSettings { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: 1000, + auth_query: None, + auth_query_user: None, + auth_query_password: None, } } } @@ -183,6 +192,9 @@ pub struct ConnectionPool { paused_waiter: Arc, pub stats: Arc, + + /// AuthInfo + pub auth_hash: Arc>>, } impl ConnectionPool { @@ -237,6 +249,7 @@ impl ConnectionPool { // Sort by shard number to ensure consistency. shard_ids.sort_by_key(|k| k.parse::().unwrap()); + let pool_auth_hash: Arc>> = Arc::new(RwLock::new(None)); for shard_idx in &shard_ids { let shard = &pool_config.shards[shard_idx]; @@ -293,12 +306,35 @@ impl ConnectionPool { replica_number += 1; } + // We assume every server in the pool share user/passwords + let auth_passthrough = AuthPassthrough::from_pool_config(pool_config); + + if let Some(apt) = &auth_passthrough { + match apt.fetch_hash(&address).await { + Ok(ok) => { + if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) { + if ok != *pool_auth_hash_value { + warn!("Hash is not the same across shards of the same pool, client auth will \ + be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); + } + } + debug!("Hash obtained for {:?}", address); + { + let mut pool_auth_hash = pool_auth_hash.write(); + *pool_auth_hash = Some(ok.clone()); + } + }, + Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err), + } + } + let manager = ServerPool::new( address.clone(), user.clone(), &shard.database, client_server_map.clone(), pool_stats.clone(), + pool_auth_hash.clone(), ); let connect_timeout = match pool_config.connect_timeout { @@ -330,6 +366,12 @@ impl ConnectionPool { } assert_eq!(shards.len(), addresses.len()); + if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) { + info!( + "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}", + pool_name, user.username + ); + } let pool = ConnectionPool { databases: shards, @@ -338,6 +380,7 @@ impl ConnectionPool { banlist: Arc::new(RwLock::new(banlist)), config_hash: new_pool_hash_value, server_info: Arc::new(RwLock::new(BytesMut::new())), + auth_hash: pool_auth_hash, settings: PoolSettings { pool_mode: pool_config.pool_mode, load_balancing_mode: pool_config.load_balancing_mode, @@ -366,6 +409,9 @@ impl ConnectionPool { .clone() .map(|regex| Regex::new(regex.as_str()).unwrap()), regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), + auth_query: pool_config.auth_query.clone(), + auth_query_user: pool_config.auth_query_user.clone(), + auth_query_password: pool_config.auth_query_password.clone(), }, validated: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)), @@ -389,7 +435,8 @@ impl ConnectionPool { Ok(()) } - /// Connect to all shards and grab server information. + /// Connect to all shards, grab server information, and possibly + /// passwords to use in client auth. /// Return server information we will pass to the clients /// when they connect. /// This also warms up the pool for clients that connect when @@ -803,6 +850,7 @@ pub struct ServerPool { database: String, client_server_map: ClientServerMap, stats: Arc, + auth_hash: Arc>>, } impl ServerPool { @@ -812,6 +860,7 @@ impl ServerPool { database: &str, client_server_map: ClientServerMap, stats: Arc, + auth_hash: Arc>>, ) -> ServerPool { ServerPool { address, @@ -819,6 +868,7 @@ impl ServerPool { database: database.to_string(), client_server_map, stats, + auth_hash, } } } @@ -847,6 +897,7 @@ impl ManageConnection for ServerPool { &self.database, self.client_server_map.clone(), stats.clone(), + self.auth_hash.clone(), ) .await { diff --git a/src/query_router.rs b/src/query_router.rs index 578c7390..0ea907b5 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -1110,6 +1110,9 @@ mod test { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: 1000, + auth_query: None, + auth_query_password: None, + auth_query_user: None, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None); @@ -1171,6 +1174,9 @@ mod test { sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()), shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()), regex_search_limit: 1000, + auth_query: None, + auth_query_password: None, + auth_query_user: None, }; let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings.clone()); diff --git a/src/server.rs b/src/server.rs index d09313ec..37f0e0c7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,11 @@ /// Implementation of the PostgreSQL server (database) protocol. /// Here we are pretending to the a Postgres client. use bytes::{Buf, BufMut, BytesMut}; +use fallible_iterator::FallibleIterator; use log::{debug, error, info, trace, warn}; +use parking_lot::{Mutex, RwLock}; +use postgres_protocol::message; +use std::collections::HashMap; use std::io::Read; use std::sync::Arc; use std::time::SystemTime; @@ -81,6 +85,7 @@ impl Server { database: &str, client_server_map: ClientServerMap, stats: Arc, + auth_hash: Arc>>, ) -> Result { let mut stream = match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { @@ -106,7 +111,10 @@ impl Server { // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete. - let mut scram = ScramSha256::new(&user.password); + let mut scram: Option = None; + if let Some(password) = &user.password.clone() { + scram = Some(ScramSha256::new(password)); + } loop { let code = match stream.read_u8().await { @@ -143,13 +151,40 @@ impl Server { Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; - md5_password(&mut stream, &user.username, &user.password, &salt[..]) - .await?; + match &user.password { + // Using plaintext password + Some(password) => { + md5_password(&mut stream, &user.username, password, &salt[..]) + .await? + } + + // Using auth passthrough, in this case we should already have a + // hash obtained when the pool was validated. If we reach this point + // and don't have a hash, we return an error. + None => { + let option_hash = (*auth_hash.read()).clone(); + match option_hash { + Some(hash) => + md5_password_with_hash( + &mut stream, + &hash, + &salt[..], + ) + .await?, + None => + return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database))) + } + } + } } AUTHENTICATION_SUCCESSFUL => (), SASL => { + if scram.is_none() { + return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database))); + } + debug!("Starting SASL authentication"); let sasl_len = (len - 8) as usize; let mut sasl_auth = vec![0u8; sasl_len]; @@ -165,7 +200,7 @@ impl Server { debug!("Using {}", SCRAM_SHA_256); // Generate client message. - let sasl_response = scram.message(); + let sasl_response = scram.as_mut().unwrap().message(); // SASLInitialResponse (F) let mut res = BytesMut::new(); @@ -202,7 +237,7 @@ impl Server { }; let msg = BytesMut::from(&sasl_data[..]); - let sasl_response = scram.update(&msg)?; + let sasl_response = scram.as_mut().unwrap().update(&msg)?; // SASLResponse let mut res = BytesMut::new(); @@ -222,7 +257,11 @@ impl Server { Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; - match scram.finish(&BytesMut::from(&sasl_final[..])) { + match scram + .as_mut() + .unwrap() + .finish(&BytesMut::from(&sasl_final[..])) + { Ok(_) => { debug!("SASL authentication successful"); } @@ -696,6 +735,105 @@ impl Server { None => (), } } + + // This is so we can execute out of band queries to the server. + // The connection will be opened, the query executed and closed. + pub async fn exec_simple_query( + address: &Address, + user: &User, + query: &str, + ) -> Result, Error> { + let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); + + debug!("Connecting to server to obtain auth hashes."); + let mut server = Server::startup( + address, + user, + &address.database, + client_server_map, + Arc::new(ServerStats::default()), + Arc::new(RwLock::new(None)), + ) + .await?; + debug!("Connected!, sending query."); + server.send(&simple_query(query)).await?; + let mut message = server.recv().await?; + + Ok(parse_query_message(&mut message).await?) + } +} + +async fn parse_query_message(message: &mut BytesMut) -> Result, Error> { + let mut pair = Vec::::new(); + match message::backend::Message::parse(message) { + Ok(Some(message::backend::Message::RowDescription(_description))) => {} + Ok(Some(message::backend::Message::ErrorResponse(err))) => { + return Err(Error::ProtocolSyncError(format!( + "Protocol error parsing response. Err: {:?}", + err.fields() + .iterator() + .fold(String::default(), |acc, element| acc + + element.unwrap().value()) + ))) + } + Ok(_) => { + return Err(Error::ProtocolSyncError( + "Protocol error, expected Row Description.".to_string(), + )) + } + Err(err) => { + return Err(Error::ProtocolSyncError(format!( + "Protocol error parsing response. Err: {:?}", + err + ))) + } + } + + while !message.is_empty() { + match message::backend::Message::parse(message) { + Ok(postgres_message) => { + match postgres_message { + Some(message::backend::Message::DataRow(data)) => { + let buf = data.buffer(); + trace!("Data: {:?}", buf); + + for item in data.ranges().iterator() { + match item.as_ref() { + Ok(range) => match range { + Some(range) => { + pair.push(String::from_utf8_lossy(&buf[range.clone()]).to_string()); + } + None => return Err(Error::ProtocolSyncError(String::from( + "Data expected while receiving query auth data, found nothing.", + ))), + }, + Err(err) => { + return Err(Error::ProtocolSyncError(format!( + "Data error, err: {:?}", + err + ))) + } + } + } + } + Some(message::backend::Message::CommandComplete(_)) => {} + Some(message::backend::Message::ReadyForQuery(_)) => {} + _ => { + return Err(Error::ProtocolSyncError( + "Unexpected message while receiving auth query data.".to_string(), + )) + } + } + } + Err(err) => { + return Err(Error::ProtocolSyncError(format!( + "Parse error, err: {:?}", + err + ))) + } + }; + } + Ok(pair) } impl Drop for Server { diff --git a/tests/docker/docker-compose.yml b/tests/docker/docker-compose.yml index e57d8529..93e94550 100644 --- a/tests/docker/docker-compose.yml +++ b/tests/docker/docker-compose.yml @@ -36,6 +36,15 @@ services: POSTGRES_PASSWORD: postgres POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256 command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] + pg5: + image: postgres:14 + network_mode: "service:main" + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 + command: ["postgres", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-p", "10432"] main: build: . command: ["bash", "/app/tests/docker/run.sh"] diff --git a/tests/ruby/auth_query_spec.rb b/tests/ruby/auth_query_spec.rb new file mode 100644 index 00000000..1ac62164 --- /dev/null +++ b/tests/ruby/auth_query_spec.rb @@ -0,0 +1,215 @@ +# frozen_string_literal: true + +require_relative 'spec_helper' +require_relative 'helpers/auth_query_helper' + +describe "Auth Query" do + let(:configured_instances) {[5432, 10432]} + let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } + let(:pg_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } + let(:processes) { Helpers::AuthQuery.single_shard_auth_query(pool_name: "sharded_db", pg_user: pg_user, config_user: config_user, extra_conf: config, wait_until_ready: wait_until_ready ) } + let(:config) { {} } + let(:wait_until_ready) { true } + + after do + unless @failing_process + processes.all_databases.map(&:reset) + processes.pgcat.shutdown + end + @failing_process = false + end + + context "when auth_query is not configured" do + context 'and cleartext passwords are set' do + it "uses local passwords" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", config_user['username'], config_user['password'])) + + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and cleartext passwords are not set' do + let(:config_user) { { 'username' => 'sharding_user' } } + + it "does not start because it is not possible to authenticate" do + @failing_process = true + expect { processes.pgcat }.to raise_error(StandardError, /You have to specify a user password for every pool if auth_query is not specified/) + end + end + end + + context 'when auth_query is configured' do + context 'with global configuration' do + around(:example) do |example| + + # Set up auth query + Helpers::AuthQuery.set_up_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + + example.run + + # Drop auth query support + Helpers::AuthQuery.tear_down_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + end + + context 'with correct global parameters' do + let(:config) { { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } } + context 'and with cleartext passwords set' do + it 'it uses local passwords' do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) + expect(conn.exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and with cleartext passwords not set' do + let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } + + it 'it uses obtained passwords' do + connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']) + conn = PG.connect(connection_string) + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + + it 'allows passwords to be changed without closing existing connections' do + pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'])) + expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';") + expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';") + end + + it 'allows passwords to be changed and that new password is needed when reconnecting' do + pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'])) + expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';") + newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2')) + expect(newconn.exec("SELECT 1 + 2")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';") + end + end + end + + context 'with wrong parameters' do + let(:config) { { 'general' => { 'auth_query' => 'SELECT 1', 'auth_query_user' => 'wrong_user', 'auth_query_password' => 'wrong' } } } + + context 'and with clear text passwords set' do + it "it uses local passwords" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) + + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and with cleartext passwords not set' do + let(:config_user) { { 'username' => 'sharding_user' } } + it "it fails to start as it cannot authenticate against servers" do + @failing_process = true + expect { PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) }.to raise_error(StandardError, /Error trying to obtain password from auth_query/ ) + end + + context 'and we fix the issue and reload' do + let(:wait_until_ready) { false } + + it 'fails in the beginning but starts working after reloading config' do + connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']) + while !(processes.pgcat.logs =~ /Waiting for clients/) do + sleep 0.5 + end + + expect { PG.connect(connection_string)}.to raise_error(PG::ConnectionBad) + expect(processes.pgcat.logs).to match(/Error trying to obtain password from auth_query/) + + current_config = processes.pgcat.current_config + config = { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } + processes.pgcat.update_config(current_config.deep_merge(config)) + processes.pgcat.reload_config + + conn = nil + expect { conn = PG.connect(connection_string)}.not_to raise_error + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + end + end + end + + context 'with per pool configuration' do + around(:example) do |example| + + # Set up auth query + Helpers::AuthQuery.set_up_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + + Helpers::AuthQuery.set_up_auth_query_for_user( + user: 'md5_auth_user1', + password: 'secret', + database: 'shard1' + ); + + example.run + + # Tear down auth query + Helpers::AuthQuery.tear_down_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + + Helpers::AuthQuery.tear_down_auth_query_for_user( + user: 'md5_auth_user1', + password: 'secret', + database: 'shard1' + ); + end + + context 'with correct parameters' do + let(:processes) { Helpers::AuthQuery.two_pools_auth_query(pool_names: ["sharded_db0", "sharded_db1"], pg_user: pg_user, config_user: config_user, extra_conf: config ) } + let(:config) { + { 'pools' => + { + 'sharded_db0' => { + 'auth_query' => "SELECT * FROM public.user_lookup('$1');", + 'auth_query_user' => 'md5_auth_user', + 'auth_query_password' => 'secret' + }, + 'sharded_db1' => { + 'auth_query' => "SELECT * FROM public.user_lookup('$1');", + 'auth_query_user' => 'md5_auth_user1', + 'auth_query_password' => 'secret' + }, + } + } + } + + context 'and with cleartext passwords set' do + it 'it uses local passwords' do + conn = PG.connect(processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])) + expect(conn.exec("SELECT 1 + 2")).not_to be_nil + conn = PG.connect(processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password'])) + expect(conn.exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and with cleartext passwords not set' do + let(:config_user) { { 'username' => 'sharding_user' } } + + it 'it uses obtained passwords' do + connection_string = processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']) + conn = PG.connect(connection_string) + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + connection_string = processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password']) + conn = PG.connect(connection_string) + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + + end + end + end +end diff --git a/tests/ruby/helpers/auth_query_helper.rb b/tests/ruby/helpers/auth_query_helper.rb new file mode 100644 index 00000000..60e85713 --- /dev/null +++ b/tests/ruby/helpers/auth_query_helper.rb @@ -0,0 +1,173 @@ +module Helpers + module AuthQuery + def self.single_shard_auth_query( + pg_user:, + config_user:, + pool_name:, + extra_conf: {}, + log_level: 'debug', + wait_until_ready: true + ) + + user = { + "pool_size" => 10, + "statement_timeout" => 0, + } + + pgcat = PgcatProcess.new(log_level) + pgcat_cfg = pgcat.current_config.deep_merge(extra_conf) + + primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0") + replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0") + + # Main proxy configs + pgcat_cfg["pools"] = { + "#{pool_name}" => { + "default_role" => "any", + "pool_mode" => "transaction", + "load_balancing_mode" => "random", + "primary_reads_enabled" => false, + "query_parser_enabled" => false, + "sharding_function" => "pg_bigint_hash", + "shards" => { + "0" => { + "database" => "shard0", + "servers" => [ + ["localhost", primary.port.to_s, "primary"], + ["localhost", replica.port.to_s, "replica"], + ] + }, + }, + "users" => { "0" => user.merge(config_user) } + } + } + pgcat_cfg["general"]["port"] = pgcat.port + pgcat.update_config(pgcat_cfg) + pgcat.start + + pgcat.wait_until_ready( + pgcat.connection_string( + "sharded_db", + pg_user['username'], + pg_user['password'] + ) + ) if wait_until_ready + + OpenStruct.new.tap do |struct| + struct.pgcat = pgcat + struct.primary = primary + struct.replicas = [replica] + struct.all_databases = [primary] + end + end + + def self.two_pools_auth_query( + pg_user:, + config_user:, + pool_names:, + extra_conf: {}, + log_level: 'debug' + ) + + user = { + "pool_size" => 10, + "statement_timeout" => 0, + } + + pgcat = PgcatProcess.new(log_level) + pgcat_cfg = pgcat.current_config + + primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0") + replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0") + + pool_template = Proc.new do |database| + { + "default_role" => "any", + "pool_mode" => "transaction", + "load_balancing_mode" => "random", + "primary_reads_enabled" => false, + "query_parser_enabled" => false, + "sharding_function" => "pg_bigint_hash", + "shards" => { + "0" => { + "database" => database, + "servers" => [ + ["localhost", primary.port.to_s, "primary"], + ["localhost", replica.port.to_s, "replica"], + ] + }, + }, + "users" => { "0" => user.merge(config_user) } + } + end + # Main proxy configs + pgcat_cfg["pools"] = { + "#{pool_names[0]}" => pool_template.call("shard0"), + "#{pool_names[1]}" => pool_template.call("shard1") + } + + pgcat_cfg["general"]["port"] = pgcat.port + pgcat.update_config(pgcat_cfg.deep_merge(extra_conf)) + pgcat.start + + pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])) + + OpenStruct.new.tap do |struct| + struct.pgcat = pgcat + struct.primary = primary + struct.replicas = [replica] + struct.all_databases = [primary] + end + end + + def self.create_query_auth_function(user) + return <<-SQL +CREATE OR REPLACE FUNCTION public.user_lookup(in i_username text, out uname text, out phash text) +RETURNS record AS $$ +BEGIN + SELECT usename, passwd FROM pg_catalog.pg_shadow + WHERE usename = i_username INTO uname, phash; + RETURN; +END; +$$ LANGUAGE plpgsql SECURITY DEFINER; + +GRANT EXECUTE ON FUNCTION public.user_lookup(text) TO #{user}; +SQL + end + + def self.exec_in_instances(query:, instance_ports: [ 5432, 10432 ], database: 'postgres', user: 'postgres', password: 'postgres') + instance_ports.each do |port| + c = PG.connect("postgres://#{user}:#{password}@localhost:#{port}/#{database}") + c.exec(query) + c.close + end + end + + def self.set_up_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' ) + instance_ports.each do |port| + connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}") + connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction + connection.exec("DROP ROLE #{user}") rescue PG::UndefinedObject + connection.exec("CREATE ROLE #{user} ENCRYPTED PASSWORD '#{password}' LOGIN;") + connection.exec(self.create_query_auth_function(user)) + connection.close + end + end + + def self.tear_down_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' ) + instance_ports.each do |port| + connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}") + connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction + connection.exec("DROP ROLE #{user}") + connection.close + end + end + + def self.drop_query_auth_function(user) + return <<-SQL +REVOKE ALL ON FUNCTION public.user_lookup(text) FROM public, #{user}; +DROP FUNCTION public.user_lookup(in i_username text, out uname text, out phash text); +SQL + end + end +end diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index c4ebab7f..13dc6686 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -3,6 +3,13 @@ require_relative 'pgcat_process' require_relative 'pg_instance' +class ::Hash + def deep_merge(second) + merger = proc { |key, v1, v2| Hash === v1 && Hash === v2 ? v1.merge(v2, &merger) : v2 } + self.merge(second, &merger) + end +end + module Helpers module Pgcat def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info") diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb index 6120c99f..e1dbea8b 100644 --- a/tests/ruby/helpers/pgcat_process.rb +++ b/tests/ruby/helpers/pgcat_process.rb @@ -67,17 +67,21 @@ def reload_config def start raise StandardError, "Process is already started" unless @pid.nil? @pid = Process.spawn(@env, @command, err: @log_filename, out: @log_filename) + Process.detach(@pid) ObjectSpace.define_finalizer(@log_filename, proc { PgcatProcess.finalize(@pid, @log_filename, @config_filename) }) return self end - def wait_until_ready + def wait_until_ready(connection_string = nil) exc = nil 10.times do - PG::connect(example_connection_string).close + Process.kill 0, @pid + PG::connect(connection_string || example_connection_string).close return self + rescue Errno::ESRCH + raise StandardError, "Process #{@pid} died. #{logs}" rescue => e exc = e sleep(0.5) @@ -108,13 +112,10 @@ def admin_connection_string "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" end - def connection_string(pool_name, username) + def connection_string(pool_name, username, password = nil) cfg = current_config - user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } - password = user_obj["password"] - - "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}" + "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" end def example_connection_string