diff --git a/src/client.rs b/src/client.rs index cc912191..9d4f4038 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,7 +11,7 @@ use crate::config::get_config; use crate::constants::*; use crate::errors::Error; use crate::messages::*; -use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; +use crate::pool::{get_pool, ClientServerMap}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; use crate::stats::{get_reporter, Reporter}; @@ -73,8 +73,13 @@ pub struct Client { /// Last server process id we talked to. last_server_id: Option, - target_pool: ConnectionPool, + /// Name of the server pool for this client (This comes from the database name in the connection string) + target_pool_name: String, + /// Postgres user for this client (This comes from the user in the connection string) + target_user_name: String, + + /// Used to notify clients about an impending shutdown shutdown_event_receiver: Receiver<()>, } @@ -305,19 +310,19 @@ where trace!("Got StartupMessage"); let parameters = parse_startup(bytes.clone())?; - let database = match parameters.get("database") { + let target_pool_name = match parameters.get("database") { Some(db) => db, None => return Err(Error::ClientError), }; - let user = match parameters.get("user") { + let target_user_name = match parameters.get("user") { Some(user) => user, None => return Err(Error::ClientError), }; let admin = ["pgcat", "pgbouncer"] .iter() - .filter(|db| *db == &database) + .filter(|db| *db == &target_pool_name) .count() == 1; @@ -352,7 +357,7 @@ where Err(_) => return Err(Error::SocketError), }; - let (target_pool, transaction_mode, server_info) = if admin { + let (transaction_mode, server_info) = if admin { let correct_user = config.general.admin_username.as_str(); let correct_password = config.general.admin_password.as_str(); @@ -360,23 +365,20 @@ where let password_hash = md5_hash_password(correct_user, correct_password, &salt); if password_hash != password_response { debug!("Password authentication failed"); - wrong_password(&mut write, user).await?; + wrong_password(&mut write, target_user_name).await?; return Err(Error::ClientError); } - ( - ConnectionPool::default(), - false, - generate_server_info_for_admin(), - ) + + (false, generate_server_info_for_admin()) } else { - let target_pool = match get_pool(database.clone(), user.clone()) { + let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) { Some(pool) => pool, None => { error_response( &mut write, &format!( "No pool configured for database: {:?}, user: {:?}", - database, user + target_pool_name, target_user_name ), ) .await?; @@ -387,14 +389,14 @@ where let server_info = target_pool.server_info(); // Compare server and client hashes. let correct_password = target_pool.settings.user.password.as_str(); - let password_hash = md5_hash_password(user, correct_password, &salt); + let password_hash = md5_hash_password(&target_user_name, correct_password, &salt); if password_hash != password_response { debug!("Password authentication failed"); - wrong_password(&mut write, user).await?; + wrong_password(&mut write, &target_user_name).await?; return Err(Error::ClientError); } - (target_pool, transaction_mode, server_info) + (transaction_mode, server_info) }; debug!("Password authentication successful"); @@ -424,7 +426,8 @@ where admin: admin, last_address_id: None, last_server_id: None, - target_pool: target_pool, + target_pool_name: target_pool_name.clone(), + target_user_name: target_user_name.clone(), shutdown_event_receiver: shutdown_event_receiver, }); } @@ -455,7 +458,8 @@ where admin: false, last_address_id: None, last_server_id: None, - target_pool: ConnectionPool::default(), + target_pool_name: String::from("undefined"), + target_user_name: String::from("undefined"), shutdown_event_receiver: shutdown_event_receiver, }); } @@ -494,7 +498,7 @@ where // The query router determines where the query is going to go, // e.g. primary, replica, which shard. - let mut query_router = QueryRouter::new(self.target_pool.clone()); + let mut query_router = QueryRouter::new(); let mut round_robin = 0; // Our custom protocol loop. @@ -520,11 +524,6 @@ where message_result = read_message(&mut self.read) => message_result? }; - // Get a pool instance referenced by the most up-to-date - // pointer. This ensures we always read the latest config - // when starting a query. - let mut pool = self.target_pool.clone(); - // Avoid taking a server if the client just wants to disconnect. if message[0] as char == 'X' { debug!("Client disconnecting"); @@ -538,6 +537,25 @@ where continue; } + // Get a pool instance referenced by the most up-to-date + // pointer. This ensures we always read the latest config + // when starting a query. + let mut pool = + match get_pool(self.target_pool_name.clone(), self.target_user_name.clone()) { + Some(pool) => pool, + None => { + error_response( + &mut self.write, + &format!( + "No pool configured for database: {:?}, user: {:?}", + self.target_pool_name, self.target_user_name + ), + ) + .await?; + return Err(Error::ClientError); + } + }; + query_router.update_pool_settings(pool.settings.clone()); let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. diff --git a/src/query_router.rs b/src/query_router.rs index d597b81e..6b377684 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -9,7 +9,7 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; -use crate::pool::{ConnectionPool, PoolSettings}; +use crate::pool::PoolSettings; use crate::sharding::{Sharder, ShardingFunction}; /// Regexes used to parse custom commands. @@ -91,16 +91,20 @@ impl QueryRouter { } /// Create a new instance of the query router. Each client gets its own. - pub fn new(target_pool: ConnectionPool) -> QueryRouter { + pub fn new() -> QueryRouter { QueryRouter { active_shard: None, active_role: None, - query_parser_enabled: target_pool.settings.query_parser_enabled, - primary_reads_enabled: target_pool.settings.primary_reads_enabled, - pool_settings: target_pool.settings, + query_parser_enabled: false, + primary_reads_enabled: false, + pool_settings: PoolSettings::default(), } } + pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) { + self.pool_settings = pool_settings; + } + /// Try to parse a command and execute it. pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> { let code = buf.get_u8() as char; @@ -363,6 +367,8 @@ impl QueryRouter { #[cfg(test)] mod test { + use std::collections::HashMap; + use super::*; use crate::messages::simple_query; use bytes::BufMut; @@ -370,7 +376,7 @@ mod test { #[test] fn test_defaults() { QueryRouter::setup(); - let qr = QueryRouter::new(ConnectionPool::default()); + let qr = QueryRouter::new(); assert_eq!(qr.role(), None); } @@ -378,7 +384,7 @@ mod test { #[test] fn test_infer_role_replica() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert_eq!(qr.query_parser_enabled(), true); @@ -402,7 +408,7 @@ mod test { #[test] fn test_infer_role_primary() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(); let queries = vec![ simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), @@ -421,7 +427,7 @@ mod test { #[test] fn test_infer_role_primary_reads_enabled() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); @@ -432,7 +438,7 @@ mod test { #[test] fn test_infer_role_parse_prepared() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(); qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); @@ -523,7 +529,7 @@ mod test { #[test] fn test_try_execute_command() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(); // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); @@ -600,7 +606,7 @@ mod test { #[test] fn test_enable_query_parser() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); @@ -621,4 +627,43 @@ mod test { assert!(qr.try_execute_command(query) != None); assert!(qr.query_parser_enabled()); } + + #[test] + fn test_update_from_pool_settings() { + QueryRouter::setup(); + + let pool_settings = PoolSettings { + pool_mode: "transaction".to_string(), + shards: HashMap::default(), + user: crate::config::User::default(), + default_role: Role::Replica.to_string(), + query_parser_enabled: true, + primary_reads_enabled: false, + sharding_function: "pg_bigint_hash".to_string(), + }; + let mut qr = QueryRouter::new(); + assert_eq!(qr.active_role, None); + assert_eq!(qr.active_shard, None); + assert_eq!(qr.query_parser_enabled, false); + assert_eq!(qr.primary_reads_enabled, false); + + // Internal state must not be changed due to this, only defaults + qr.update_pool_settings(pool_settings.clone()); + + assert_eq!(qr.active_role, None); + assert_eq!(qr.active_shard, None); + assert_eq!(qr.query_parser_enabled, false); + assert_eq!(qr.primary_reads_enabled, false); + + let q1 = simple_query("SET SERVER ROLE TO 'primary'"); + assert!(qr.try_execute_command(q1) != None); + assert_eq!(qr.active_role.unwrap(), Role::Primary); + + let q2 = simple_query("SET SERVER ROLE TO 'default'"); + assert!(qr.try_execute_command(q2) != None); + assert_eq!( + qr.active_role.unwrap().to_string(), + pool_settings.clone().default_role + ); + } } diff --git a/tests/ruby/.ruby-version b/tests/ruby/.ruby-version index 860487ca..cf232b52 100644 --- a/tests/ruby/.ruby-version +++ b/tests/ruby/.ruby-version @@ -1 +1,2 @@ -2.7.1 +3.0.0 + diff --git a/tests/ruby/Gemfile b/tests/ruby/Gemfile index 05684c98..7b019183 100644 --- a/tests/ruby/Gemfile +++ b/tests/ruby/Gemfile @@ -3,3 +3,4 @@ source "https://rubygems.org" gem "pg" gem "activerecord" gem "rubocop" +gem "toml", "~> 0.3.0" diff --git a/tests/ruby/Gemfile.lock b/tests/ruby/Gemfile.lock index 607df18c..3fd03471 100644 --- a/tests/ruby/Gemfile.lock +++ b/tests/ruby/Gemfile.lock @@ -19,6 +19,7 @@ GEM parallel (1.22.1) parser (3.1.2.0) ast (~> 2.4.1) + parslet (2.0.0) pg (1.3.2) rainbow (3.1.1) regexp_parser (2.3.1) @@ -35,17 +36,21 @@ GEM rubocop-ast (1.17.0) parser (>= 3.1.1.0) ruby-progressbar (1.11.0) + toml (0.3.0) + parslet (>= 1.8.0, < 3.0.0) tzinfo (2.0.4) concurrent-ruby (~> 1.0) unicode-display_width (2.1.0) PLATFORMS + arm64-darwin-21 x86_64-linux DEPENDENCIES activerecord pg rubocop + toml (~> 0.3.0) BUNDLED WITH 2.3.7 diff --git a/tests/ruby/tests.rb b/tests/ruby/tests.rb index c5a55a7e..ba9476f4 100644 --- a/tests/ruby/tests.rb +++ b/tests/ruby/tests.rb @@ -2,6 +2,7 @@ require 'active_record' require 'pg' +require 'toml' $stdout.sync = true @@ -141,3 +142,62 @@ def test_server_parameters puts 'Server parameters ok' end + + +class ConfigEditor + def initialize + @original_config_text = File.read('../../.circleci/pgcat.toml') + text_to_load = @original_config_text.gsub("5432", "\"5432\"") + + @original_configs = TOML.load(text_to_load) + end + + def original_configs + TOML.load(TOML::Generator.new(@original_configs).body) + end + + def with_modified_configs(new_configs) + text_to_write = TOML::Generator.new(new_configs).body + text_to_write = text_to_write.gsub("\"5432\"", "5432") + File.write('../../.circleci/pgcat.toml', text_to_write) + yield + ensure + File.write('../../.circleci/pgcat.toml', @original_config_text) + end + +end + + +def test_reload_pool_recycling + admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") + server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") + + server_conn.async_exec("BEGIN") + conf_editor = ConfigEditor.new + new_configs = conf_editor.original_configs + + # swap shards + new_configs["pools"]["sharded_db"]["shards"]["0"]["database"] = "shard1" + new_configs["pools"]["sharded_db"]["shards"]["1"]["database"] = "shard0" + + raise StandardError if server_conn.async_exec("SELECT current_database();")[0]["current_database"] != 'shard0' + conf_editor.with_modified_configs(new_configs) { admin_conn.async_exec("RELOAD") } + raise StandardError if server_conn.async_exec("SELECT current_database();")[0]["current_database"] != 'shard0' + server_conn.async_exec("COMMIT;") + + # Transaction finished, client should get new configs + raise StandardError if server_conn.async_exec("SELECT current_database();")[0]["current_database"] != 'shard1' + server_conn.close() + + # New connection should get new configs + server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") + raise StandardError if server_conn.async_exec("SELECT current_database();")[0]["current_database"] != 'shard1' + +ensure + admin_conn.async_exec("RELOAD") # Go back to old state + admin_conn.close + server_conn.close + puts "Pool Recycling okay!" +end + +test_reload_pool_recycling