diff --git a/src/client.rs b/src/client.rs
index 278cda8c..0c553f83 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -499,7 +499,6 @@ where
         // The query router determines where the query is going to go,
         // e.g. primary, replica, which shard.
         let mut query_router = QueryRouter::new();
-        let mut round_robin = rand::random();
 
         // Our custom protocol loop.
         // We expect the client to either start a transaction with regular queries
@@ -631,12 +630,7 @@ where
 
             // Grab a server from the pool.
             let connection = match pool
-                .get(
-                    query_router.shard(),
-                    query_router.role(),
-                    self.process_id,
-                    round_robin,
-                )
+                .get(query_router.shard(), query_router.role(), self.process_id)
                 .await
             {
                 Ok(conn) => {
@@ -655,8 +649,6 @@ where
             let address = connection.1;
             let server = &mut *reference;
 
-            round_robin += 1;
-
             // Server is assigned to the client in case the client wants to
             // cancel a query later.
             server.claim(self.process_id, self.secret_key);
diff --git a/src/config.rs b/src/config.rs
index ed338104..9d1658ff 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -63,7 +63,7 @@ pub struct Address {
     pub shard: usize,
     pub database: String,
     pub role: Role,
-    pub replica_number: usize,
+    pub instance_index: usize,
     pub username: String,
     pub poolname: String,
 }
@@ -75,7 +75,7 @@ impl Default for Address {
             host: String::from("127.0.0.1"),
             port: String::from("5432"),
             shard: 0,
-            replica_number: 0,
+            instance_index: 0,
             database: String::from("database"),
             role: Role::Replica,
             username: String::from("username"),
@@ -92,7 +92,7 @@
 
             Role::Replica => format!(
                 "{}_shard_{}_replica_{}",
-                self.poolname, self.shard, self.replica_number
+                self.poolname, self.shard, self.instance_index
             ),
         }
     }
diff --git a/src/pool.rs b/src/pool.rs
index 5684d545..cbb9b43f 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -6,6 +6,8 @@ use chrono::naive::NaiveDateTime;
 use log::{debug, error, info, warn};
 use once_cell::sync::Lazy;
 use parking_lot::{Mutex, RwLock};
+use rand::seq::SliceRandom;
+use rand::thread_rng;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Instant;
@@ -118,7 +120,7 @@
                     host: server.0.clone(),
                     port: server.1.to_string(),
                     role: role,
-                    replica_number,
+                    instance_index: replica_number,
                     shard: shard_idx.parse::<usize>().unwrap(),
                     username: user_info.username.clone(),
                     poolname: pool_name.clone(),
@@ -201,16 +203,9 @@ impl ConnectionPool {
     /// the pooler starts up.
     async fn validate(&mut self) -> Result<(), Error> {
         let mut server_infos = Vec::new();
-        let stats = self.stats.clone();
-
         for shard in 0..self.shards() {
-            let mut round_robin = 0;
-
-            for _ in 0..self.servers(shard) {
-                // To keep stats consistent.
-                let fake_process_id = 0;
-
-                let connection = match self.get(shard, None, fake_process_id, round_robin).await {
+            for index in 0..self.servers(shard) {
+                let connection = match self.databases[shard][index].get().await {
                     Ok(conn) => conn,
                     Err(err) => {
                         error!("Shard {} down or misconfigured: {:?}", shard, err);
@@ -218,25 +213,20 @@ impl ConnectionPool {
                     }
                 };
 
-                let proxy = connection.0;
-                let address = connection.1;
+                let proxy = connection;
 
                 let server = &*proxy;
                 let server_info = server.server_info();
 
-                stats.client_disconnecting(fake_process_id, address.id);
-
                 if server_infos.len() > 0 {
                     // Compare against the last server checked.
                     if server_info != server_infos[server_infos.len() - 1] {
                         warn!(
                             "{:?} has different server configuration than the last server",
-                            address
+                            proxy.address()
                         );
                     }
                 }
-
                 server_infos.push(server_info);
-                round_robin += 1;
             }
         }
@@ -254,58 +244,31 @@ impl ConnectionPool {
     /// Get a connection from the pool.
     pub async fn get(
         &self,
-        shard: usize,           // shard number
-        role: Option<Role>,     // primary or replica
-        process_id: i32,        // client id
-        mut round_robin: usize, // round robin offset
+        shard: usize,       // shard number
+        role: Option<Role>, // primary or replica
+        process_id: i32,    // client id
     ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
         let now = Instant::now();
-        let addresses = &self.addresses[shard];
-
-        let mut allowed_attempts = match role {
-            // Primary-specific queries get one attempt, if the primary is down,
-            // nothing we should do about it I think. It's dangerous to retry
-            // write queries.
-            Some(Role::Primary) => 1,
+        let mut candidates: Vec<Address> = self.addresses[shard]
+            .clone()
+            .into_iter()
+            .filter(|address| address.role == role)
+            .collect();
 
-            // Replicas get to try as many times as there are replicas
-            // and connections in the pool.
-            _ => addresses.len(),
-        };
-
-        debug!("Allowed attempts for {:?}: {}", role, allowed_attempts);
-
-        let exists = match role {
-            Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0,
-            None => true,
-        };
-
-        if !exists {
-            error!("Requested role {:?}, but none are configured", role);
-            return Err(Error::BadConfig);
-        }
+        // Random load balancing
+        candidates.shuffle(&mut thread_rng());
 
         let healthcheck_timeout = get_config().general.healthcheck_timeout;
         let healthcheck_delay = get_config().general.healthcheck_delay as u128;
 
-        while allowed_attempts > 0 {
-            // Round-robin replicas.
-            round_robin += 1;
-
-            let index = round_robin % addresses.len();
-            let address = &addresses[index];
-
-            // Make sure you're getting a primary or a replica
-            // as per request. If no specific role is requested, the first
-            // available will be chosen.
-            if address.role != role {
-                continue;
-            }
-
-            allowed_attempts -= 1;
+        while !candidates.is_empty() {
+            // Get the next candidate
+            let address = match candidates.pop() {
+                Some(address) => address,
+                None => break,
+            };
 
-            // Don't attempt to connect to banned servers.
-            if self.is_banned(address, shard, role) {
+            if self.is_banned(&address, address.shard, role) {
                 continue;
             }
 
@@ -313,11 +276,14 @@ impl ConnectionPool {
             self.stats.client_waiting(process_id, address.id);
 
            // Check if we can connect
-            let mut conn = match self.databases[shard][index].get().await {
+            let mut conn = match self.databases[address.shard][address.instance_index]
+                .get()
+                .await
+            {
                 Ok(conn) => conn,
                 Err(err) => {
-                    error!("Banning replica {}, error: {:?}", index, err);
-                    self.ban(address, shard, process_id);
+                    error!("Banning instance {:?}, error: {:?}", address, err);
+                    self.ban(&address, address.shard, process_id);
                     self.stats.client_disconnecting(process_id, address.id);
                     self.stats
                         .checkout_time(now.elapsed().as_micros(), process_id, address.id);
@@ -359,29 +325,34 @@
                     }
 
                     // Health check failed.
-                    Err(_) => {
-                        error!("Banning replica {} because of failed health check", index);
+                    Err(err) => {
+                        error!(
+                            "Banning instance {:?} because of failed health check, {:?}",
+                            address, err
+                        );
 
                         // Don't leave a bad connection in the pool.
                         server.mark_bad();
 
-                        self.ban(address, shard, process_id);
+                        self.ban(&address, address.shard, process_id);
                         continue;
                     }
                 },
 
                 // Health check timed out.
-                Err(_) => {
-                    error!("Banning replica {} because of health check timeout", index);
+                Err(err) => {
+                    error!(
+                        "Banning instance {:?} because of health check timeout, {:?}",
+                        address, err
+                    );
 
                     // Don't leave a bad connection in the pool.
                     server.mark_bad();
 
-                    self.ban(address, shard, process_id);
+                    self.ban(&address, address.shard, process_id);
                     continue;
                 }
             }
         }
-
         return Err(Error::AllServersDown);
     }
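Note: the selection strategy introduced above (filter candidates by role, shuffle once per request, pop until one connects) can be illustrated standalone. The sketch below is a minimal approximation, not this pool's actual API: `Address`, `Role`, and `try_connect` are simplified stand-ins, and banning/stats bookkeeping is reduced to skipping a failed candidate.

```rust
use rand::seq::SliceRandom;
use rand::thread_rng;

#[derive(Clone, Debug, PartialEq)]
enum Role {
    Primary,
    Replica,
}

#[derive(Clone, Debug)]
struct Address {
    host: String,
    role: Role,
}

// Stand-in for the real checkout + health check; pretend one replica is
// down so the "skip and try the next candidate" path is exercised.
fn try_connect(address: &Address) -> Result<(), ()> {
    if address.host == "10.0.0.1" {
        Err(())
    } else {
        Ok(())
    }
}

fn get(addresses: &[Address], role: Option<Role>) -> Option<Address> {
    // Keep only candidates matching the requested role; `None` means any role.
    let mut candidates: Vec<Address> = addresses
        .iter()
        .filter(|address| role.is_none() || Some(&address.role) == role.as_ref())
        .cloned()
        .collect();

    // Random load balancing: shuffle once per request...
    candidates.shuffle(&mut thread_rng());

    // ...then try each candidate at most once by popping from the end.
    while let Some(address) = candidates.pop() {
        if try_connect(&address).is_ok() {
            return Some(address);
        }
        // A failed candidate is simply skipped (the real pool also bans it).
    }

    // Every candidate failed: the caller maps this to Error::AllServersDown.
    None
}

fn main() {
    let addresses = vec![
        Address { host: "10.0.0.1".into(), role: Role::Replica },
        Address { host: "10.0.0.2".into(), role: Role::Replica },
        Address { host: "10.0.0.3".into(), role: Role::Primary },
    ];

    // Always lands on 10.0.0.2: 10.0.0.3 is filtered out by role and
    // 10.0.0.1 fails its connection attempt.
    println!("{:?}", get(&addresses, Some(Role::Replica)));
}
```

Compared to the shared `round_robin` counter this diff removes, a per-request shuffle needs no mutable state threaded through the client handler, and popping from the shuffled list bounds the loop at one attempt per candidate, which is what makes the old `allowed_attempts` bookkeeping unnecessary.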