diff --git a/src/admin.rs b/src/admin.rs index 1aa2bced..d794b86a 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -215,16 +215,16 @@ where let mut res = BytesMut::new(); res.put(row_description(&columns)); - for ((pool_name, username), pool) in get_all_pools() { + for (user_pool, pool) in get_all_pools() { let def = HashMap::default(); let pool_stats = all_pool_stats - .get(&(pool_name.clone(), username.clone())) + .get(&(user_pool.db.clone(), user_pool.user.clone())) .unwrap_or(&def); let pool_config = &pool.settings; let mut row = vec![ - pool_name.clone(), - username.clone(), + user_pool.db.clone(), + user_pool.user.clone(), pool_config.pool_mode.to_string(), ]; for column in &columns[3..columns.len()] { @@ -420,7 +420,7 @@ where let mut res = BytesMut::new(); res.put(row_description(&columns)); - for ((db, username), pool) in get_all_pools() { + for (user_pool, pool) in get_all_pools() { for shard in 0..pool.shards() { for server in 0..pool.servers(shard) { let address = pool.address(shard, server); @@ -429,7 +429,7 @@ where None => HashMap::new(), }; - let mut row = vec![address.name(), db.clone(), username.clone()]; + let mut row = vec![address.name(), user_pool.db.clone(), user_pool.user.clone()]; for column in &columns[3..] { row.push(stats.get(column.0).unwrap_or(&0).to_string()); } diff --git a/src/client.rs b/src/client.rs index cdef30f6..64dfa8ac 100644 --- a/src/client.rs +++ b/src/client.rs @@ -446,7 +446,7 @@ where } // Authenticate normal user. else { - let pool = match get_pool(pool_name.clone(), username.clone()) { + let pool = match get_pool(&pool_name, &username) { Some(pool) => pool, None => { error_response( @@ -648,7 +648,7 @@ where // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config // when starting a query. - let pool = match get_pool(self.pool_name.clone(), self.username.clone()) { + let pool = match get_pool(&self.pool_name, &self.username) { Some(pool) => pool, None => { error_response( diff --git a/src/pool.rs b/src/pool.rs index 815a2b8b..bb452537 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -19,9 +19,15 @@ use crate::server::Server; use crate::sharding::ShardingFunction; use crate::stats::{get_reporter, Reporter}; +pub type ProcessId = i32; +pub type SecretKey = i32; +pub type ServerHost = String; +pub type ServerPort = u16; + pub type BanList = Arc>>>; -pub type ClientServerMap = Arc>>; -pub type PoolMap = HashMap<(String, String), ConnectionPool>; +pub type ClientServerMap = + Arc>>; +pub type PoolMap = HashMap; /// The connection pool, globally available. /// This is atomic and safe and read-optimized. /// The pool is recreated dynamically when the config is reloaded. @@ -29,6 +35,27 @@ pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(Ha static POOLS_HASH: Lazy>> = Lazy::new(|| ArcSwap::from_pointee(HashSet::default())); +/// An identifier for a PgCat pool, +/// a database visible to clients. +#[derive(Hash, Debug, Clone, PartialEq, Eq)] +pub struct PoolIdentifier { + // The name of the database clients want to connect to. + pub db: String, + + /// The username the client connects with. Each user gets its own pool. + pub user: String, +} + +impl PoolIdentifier { + /// Create a new user/pool identifier. + pub fn new(db: &str, user: &str) -> PoolIdentifier { + PoolIdentifier { + db: db.to_string(), + user: user.to_string(), + } + } +} + /// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { @@ -113,14 +140,16 @@ impl ConnectionPool { // If the pool hasn't changed, get existing reference and insert it into the new_pools. // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). if !changed { - match get_pool(pool_name.clone(), user.username.clone()) { + match get_pool(&pool_name, &user.username) { Some(pool) => { info!( "[pool: {}][user: {}] has not changed", pool_name, user.username ); - new_pools - .insert((pool_name.clone(), user.username.clone()), pool.clone()); + new_pools.insert( + PoolIdentifier::new(&pool_name, &user.username), + pool.clone(), + ); continue; } None => (), @@ -239,7 +268,7 @@ impl ConnectionPool { }; // There is one pool per database/user pair. - new_pools.insert((pool_name.clone(), user.username.clone()), pool); + new_pools.insert(PoolIdentifier::new(&pool_name, &user.username), pool); } } @@ -603,15 +632,15 @@ impl ManageConnection for ServerPool { } /// Get the connection pool -pub fn get_pool(db: String, user: String) -> Option { - match get_all_pools().get(&(db, user)) { +pub fn get_pool(db: &str, user: &str) -> Option { + match get_all_pools().get(&PoolIdentifier::new(&db, &user)) { Some(pool) => Some(pool.clone()), None => None, } } /// Get a pointer to all configured pools. -pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { +pub fn get_all_pools() -> HashMap { return (*(*POOLS.load())).clone(); } diff --git a/src/stats.rs b/src/stats.rs index e37c88c7..7998e454 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -521,11 +521,11 @@ impl Collector { tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD / 15)); loop { interval.tick().await; - for ((pool_name, username), _pool) in get_all_pools() { + for (user_pool, _) in get_all_pools() { let _ = tx.try_send(Event { name: EventName::UpdateStats { - pool_name, - username, + pool_name: user_pool.db, + username: user_pool.user, }, value: 0, });