Random instance selection #136

Merged: 18 commits, Aug 22, 2022
10 changes: 1 addition & 9 deletions src/client.rs
@@ -499,7 +499,6 @@ where
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
let mut round_robin = rand::random();

// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
@@ -631,12 +630,7 @@ where

// Grab a server from the pool.
let connection = match pool
.get(
query_router.shard(),
query_router.role(),
self.process_id,
round_robin,
)
.get(query_router.shard(), query_router.role(), self.process_id)
.await
{
Ok(conn) => {
@@ -655,8 +649,6 @@ where
let address = connection.1;
let server = &mut *reference;

round_robin += 1;

// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
6 changes: 3 additions & 3 deletions src/config.rs
@@ -63,7 +63,7 @@ pub struct Address {
pub shard: usize,
pub database: String,
pub role: Role,
pub replica_number: usize,
pub instance_index: usize,

Review comment (author): An Address could point to a primary, so using the term "replica" here is misleading. (See the sketch after this file's diff.)

pub username: String,
pub poolname: String,
}
@@ -75,7 +75,7 @@ impl Default for Address {
host: String::from("127.0.0.1"),
port: String::from("5432"),
shard: 0,
replica_number: 0,
instance_index: 0,
database: String::from("database"),
role: Role::Replica,
username: String::from("username"),
@@ -92,7 +92,7 @@ impl Address {

Role::Replica => format!(
"{}_shard_{}_replica_{}",
self.poolname, self.shard, self.replica_number
self.poolname, self.shard, self.instance_index
),
}
}
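
For illustration, a minimal, self-contained sketch of how the renamed field fits into the struct and its stats name. The Primary arm of name() and the exact field set are assumptions made for this example, not code from the PR (the diff above only shows the Replica arm).

// Sketch only: `instance_index` is the server's position within its shard,
// and the server behind it may be a primary or a replica, which is why
// `replica_number` was a misleading name.
#[derive(Clone, Debug, PartialEq)]
enum Role {
    Primary,
    Replica,
}

#[derive(Clone, Debug)]
struct Address {
    shard: usize,
    instance_index: usize,
    role: Role,
    poolname: String,
}

impl Address {
    // The Replica arm mirrors the diff above; the Primary arm is an assumed
    // placeholder for this sketch.
    fn name(&self) -> String {
        match self.role {
            Role::Primary => format!("{}_shard_{}_primary", self.poolname, self.shard),
            Role::Replica => format!(
                "{}_shard_{}_replica_{}",
                self.poolname, self.shard, self.instance_index
            ),
        }
    }
}

fn main() {
    let replica = Address {
        shard: 0,
        instance_index: 1,
        role: Role::Replica,
        poolname: String::from("pool"),
    };
    assert_eq!(replica.name(), "pool_shard_0_replica_1");
}
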
113 changes: 42 additions & 71 deletions src/pool.rs
@@ -6,6 +6,8 @@ use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
@@ -118,7 +120,7 @@ impl ConnectionPool {
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
instance_index: replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user_info.username.clone(),
poolname: pool_name.clone(),
@@ -201,42 +203,30 @@ impl ConnectionPool {
/// the pooler starts up.
async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new();
let stats = self.stats.clone();

for shard in 0..self.shards() {
let mut round_robin = 0;

for _ in 0..self.servers(shard) {
// To keep stats consistent.
let fake_process_id = 0;

let connection = match self.get(shard, None, fake_process_id, round_robin).await {
for index in 0..self.servers(shard) {
let connection = match self.databases[shard][index].get().await {
Ok(conn) => conn,
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err);
continue;
}
};

let proxy = connection.0;
let address = connection.1;
let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();

stats.client_disconnecting(fake_process_id, address.id);

if server_infos.len() > 0 {
// Compare against the last server checked.
if server_info != server_infos[server_infos.len() - 1] {
warn!(
"{:?} has different server configuration than the last server",
address
proxy.address()
);
}
}

server_infos.push(server_info);
round_robin += 1;
}
}

@@ -254,70 +244,46 @@ impl ConnectionPool {
/// Get a connection from the pool.
pub async fn get(
&self,
shard: usize, // shard number
role: Option<Role>, // primary or replica
process_id: i32, // client id
mut round_robin: usize, // round robin offset

Review comment (author): After randomization, this parameter is no longer needed; it was only used by the validate method to iterate over all instances. To simplify the logic of this method, I refactored validate to check out connections directly from the underlying pool, using shard_index and instance_index. (See the sketch at the end of this diff.)

shard: usize, // shard number
role: Option<Role>, // primary or replica
process_id: i32, // client id
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let now = Instant::now();
let addresses = &self.addresses[shard];

let mut allowed_attempts = match role {
// Primary-specific queries get one attempt, if the primary is down,
// nothing we should do about it I think. It's dangerous to retry
// write queries.
Some(Role::Primary) => 1,
let mut candidates: Vec<Address> = self.addresses[shard]
.clone()
.into_iter()
.filter(|address| address.role == role)
.collect();

// Replicas get to try as many times as there are replicas
// and connections in the pool.
_ => addresses.len(),
};

debug!("Allowed attempts for {:?}: {}", role, allowed_attempts);

let exists = match role {
Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0,
None => true,
};

if !exists {
error!("Requested role {:?}, but none are configured", role);
return Err(Error::BadConfig);
}
// Random load balancing
candidates.shuffle(&mut thread_rng());

let healthcheck_timeout = get_config().general.healthcheck_timeout;
let healthcheck_delay = get_config().general.healthcheck_delay as u128;

while allowed_attempts > 0 {
// Round-robin replicas.
round_robin += 1;

let index = round_robin % addresses.len();
let address = &addresses[index];

// Make sure you're getting a primary or a replica
// as per request. If no specific role is requested, the first
// available will be chosen.
if address.role != role {
continue;
}

allowed_attempts -= 1;
while !candidates.is_empty() {
// Get the next candidate
let address = match candidates.pop() {
Some(address) => address,
None => break,
};

// Don't attempt to connect to banned servers.
if self.is_banned(address, shard, role) {
if self.is_banned(&address, address.shard, role) {
continue;
}

// Indicate we're waiting on a server connection from a pool.
self.stats.client_waiting(process_id, address.id);

// Check if we can connect
let mut conn = match self.databases[shard][index].get().await {
let mut conn = match self.databases[address.shard][address.instance_index]
.get()
.await
{
Ok(conn) => conn,
Err(err) => {
error!("Banning replica {}, error: {:?}", index, err);
self.ban(address, shard, process_id);
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(&address, address.shard, process_id);
self.stats.client_disconnecting(process_id, address.id);
self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
@@ -359,29 +325,34 @@ impl ConnectionPool {
}

// Health check failed.
Err(_) => {
error!("Banning replica {} because of failed health check", index);
Err(err) => {
error!(
"Banning instance {:?} because of failed health check, {:?}",
address, err
);

// Don't leave a bad connection in the pool.
server.mark_bad();

self.ban(address, shard, process_id);
self.ban(&address, address.shard, process_id);
continue;
}
},

// Health check timed out.
Err(_) => {
error!("Banning replica {} because of health check timeout", index);
Err(err) => {
error!(
"Banning instance {:?} because of health check timeout, {:?}",
address, err
);
// Don't leave a bad connection in the pool.
server.mark_bad();

self.ban(address, shard, process_id);
self.ban(&address, address.shard, process_id);
continue;
}
}
}

return Err(Error::AllServersDown);
}

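
To summarize the load-balancing change in one place, here is a minimal, self-contained sketch of the shuffle-then-try approach this PR switches to. The types and the shuffled_candidates helper are simplified stand-ins for the real code in src/config.rs and src/pool.rs, not the actual implementation; it assumes the rand crate that the PR already imports.

use rand::seq::SliceRandom;
use rand::thread_rng;

#[derive(Clone, Debug, PartialEq)]
enum Role {
    Primary,
    Replica,
}

#[derive(Clone, Debug)]
struct Address {
    shard: usize,
    instance_index: usize,
    role: Role,
}

// Return the shard's servers, filtered by the requested role (if any), in a
// fresh random order. The caller pops candidates one by one, skipping banned
// servers and banning any that fail to connect or fail a health check, as
// ConnectionPool::get does in the diff above.
fn shuffled_candidates(addresses: &[Address], role: Option<Role>) -> Vec<Address> {
    let mut candidates: Vec<Address> = addresses
        .iter()
        .filter(|address| role.is_none() || Some(&address.role) == role.as_ref())
        .cloned()
        .collect();

    // Random load balancing instead of a per-client round-robin offset.
    candidates.shuffle(&mut thread_rng());
    candidates
}

fn main() {
    let addresses = vec![
        Address { shard: 0, instance_index: 0, role: Role::Primary },
        Address { shard: 0, instance_index: 1, role: Role::Replica },
        Address { shard: 0, instance_index: 2, role: Role::Replica },
    ];

    // Each call returns the matching instances in a new random order.
    for address in shuffled_candidates(&addresses, Some(Role::Replica)) {
        println!("would try shard {} instance {}", address.shard, address.instance_index);
    }
}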