Skip to content

Prevent clients from sticking to old pools after config update #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::config::get_config;
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
Expand Down Expand Up @@ -73,8 +73,13 @@ pub struct Client<S, T> {
/// Last server process id we talked to.
last_server_id: Option<i32>,

target_pool: ConnectionPool,
/// Name of the server pool for this client (This comes from the database name in the connection string)
target_pool_name: String,

/// Postgres user for this client (This comes from the user in the connection string)
target_user_name: String,

/// Used to notify clients about an impending shutdown
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation

shutdown_event_receiver: Receiver<()>,
}

Expand Down Expand Up @@ -305,19 +310,19 @@ where

trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let database = match parameters.get("database") {
let target_pool_name = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};

let user = match parameters.get("user") {
let target_user_name = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};

let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.filter(|db| *db == &target_pool_name)
.count()
== 1;

Expand Down Expand Up @@ -352,31 +357,28 @@ where
Err(_) => return Err(Error::SocketError),
};

let (target_pool, transaction_mode, server_info) = if admin {
let (transaction_mode, server_info) = if admin {
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();

// Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
wrong_password(&mut write, target_user_name).await?;
return Err(Error::ClientError);
}
(
ConnectionPool::default(),
false,
generate_server_info_for_admin(),
)

(false, generate_server_info_for_admin())
} else {
let target_pool = match get_pool(database.clone(), user.clone()) {
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
database, user
target_pool_name, target_user_name
),
)
.await?;
Expand All @@ -387,14 +389,14 @@ where
let server_info = target_pool.server_info();
// Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(user, correct_password, &salt);
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);

if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
wrong_password(&mut write, &target_user_name).await?;
return Err(Error::ClientError);
}
(target_pool, transaction_mode, server_info)
(transaction_mode, server_info)
};

debug!("Password authentication successful");
Expand Down Expand Up @@ -424,7 +426,8 @@ where
admin: admin,
last_address_id: None,
last_server_id: None,
target_pool: target_pool,
target_pool_name: target_pool_name.clone(),
target_user_name: target_user_name.clone(),
shutdown_event_receiver: shutdown_event_receiver,
});
}
Expand Down Expand Up @@ -455,7 +458,8 @@ where
admin: false,
last_address_id: None,
last_server_id: None,
target_pool: ConnectionPool::default(),
target_pool_name: String::from("undefined"),
target_user_name: String::from("undefined"),
shutdown_event_receiver: shutdown_event_receiver,
});
}
Expand Down Expand Up @@ -494,7 +498,7 @@ where

// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(self.target_pool.clone());
let mut query_router = QueryRouter::new();
let mut round_robin = 0;

// Our custom protocol loop.
Expand All @@ -520,11 +524,6 @@ where
message_result = read_message(&mut self.read) => message_result?
};

// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.target_pool.clone();

// Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' {
debug!("Client disconnecting");
Expand All @@ -538,6 +537,25 @@ where
continue;
}

// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool =
match get_pool(self.target_pool_name.clone(), self.target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut self.write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.target_pool_name, self.target_user_name
),
)
.await?;
return Err(Error::ClientError);
}
};
query_router.update_pool_settings(pool.settings.clone());
let current_shard = query_router.shard();

// Handle all custom protocol commands, if any.
Expand Down
69 changes: 57 additions & 12 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;

use crate::config::Role;
use crate::pool::{ConnectionPool, PoolSettings};
use crate::pool::PoolSettings;
use crate::sharding::{Sharder, ShardingFunction};

/// Regexes used to parse custom commands.
Expand Down Expand Up @@ -91,16 +91,20 @@ impl QueryRouter {
}

/// Create a new instance of the query router. Each client gets its own.
pub fn new(target_pool: ConnectionPool) -> QueryRouter {
pub fn new() -> QueryRouter {
QueryRouter {
active_shard: None,
active_role: None,
query_parser_enabled: target_pool.settings.query_parser_enabled,
primary_reads_enabled: target_pool.settings.primary_reads_enabled,
pool_settings: target_pool.settings,
query_parser_enabled: false,
primary_reads_enabled: false,
pool_settings: PoolSettings::default(),
}
}

pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
}

/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> {
let code = buf.get_u8() as char;
Expand Down Expand Up @@ -363,22 +367,24 @@ impl QueryRouter {

#[cfg(test)]
mod test {
use std::collections::HashMap;

use super::*;
use crate::messages::simple_query;
use bytes::BufMut;

#[test]
fn test_defaults() {
QueryRouter::setup();
let qr = QueryRouter::new(ConnectionPool::default());
let qr = QueryRouter::new();

assert_eq!(qr.role(), None);
}

#[test]
fn test_infer_role_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true);

Expand All @@ -402,7 +408,7 @@ mod test {
#[test]
fn test_infer_role_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();

let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
Expand All @@ -421,7 +427,7 @@ mod test {
#[test]
fn test_infer_role_primary_reads_enabled() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);

Expand All @@ -432,7 +438,7 @@ mod test {
#[test]
fn test_infer_role_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

Expand Down Expand Up @@ -523,7 +529,7 @@ mod test {
#[test]
fn test_try_execute_command() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();

// SetShardingKey
let query = simple_query("SET SHARDING KEY TO 13");
Expand Down Expand Up @@ -600,7 +606,7 @@ mod test {
#[test]
fn test_enable_query_parser() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

Expand All @@ -621,4 +627,43 @@ mod test {
assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
}

#[test]
fn test_update_from_pool_settings() {
QueryRouter::setup();

let pool_settings = PoolSettings {
pool_mode: "transaction".to_string(),
shards: HashMap::default(),
user: crate::config::User::default(),
default_role: Role::Replica.to_string(),
query_parser_enabled: true,
primary_reads_enabled: false,
sharding_function: "pg_bigint_hash".to_string(),
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
assert_eq!(qr.query_parser_enabled, false);
assert_eq!(qr.primary_reads_enabled, false);

// Internal state must not be changed due to this, only defaults
qr.update_pool_settings(pool_settings.clone());

assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
assert_eq!(qr.query_parser_enabled, false);
assert_eq!(qr.primary_reads_enabled, false);

let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(q1) != None);
assert_eq!(qr.active_role.unwrap(), Role::Primary);

let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(q2) != None);
assert_eq!(
qr.active_role.unwrap().to_string(),
pool_settings.clone().default_role
);
}
}
3 changes: 2 additions & 1 deletion tests/ruby/.ruby-version
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
2.7.1
3.0.0

1 change: 1 addition & 0 deletions tests/ruby/Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ source "https://rubygems.org"
gem "pg"
gem "activerecord"
gem "rubocop"
gem "toml", "~> 0.3.0"
5 changes: 5 additions & 0 deletions tests/ruby/Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ GEM
parallel (1.22.1)
parser (3.1.2.0)
ast (~> 2.4.1)
parslet (2.0.0)
pg (1.3.2)
rainbow (3.1.1)
regexp_parser (2.3.1)
Expand All @@ -35,17 +36,21 @@ GEM
rubocop-ast (1.17.0)
parser (>= 3.1.1.0)
ruby-progressbar (1.11.0)
toml (0.3.0)
parslet (>= 1.8.0, < 3.0.0)
tzinfo (2.0.4)
concurrent-ruby (~> 1.0)
unicode-display_width (2.1.0)

PLATFORMS
arm64-darwin-21
x86_64-linux

DEPENDENCIES
activerecord
pg
rubocop
toml (~> 0.3.0)

BUNDLED WITH
2.3.7
Loading