diff --git a/src/client.rs b/src/client.rs index 1775ad22..82ad5bf9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -25,6 +25,13 @@ enum ClientConnectionType { CancelQuery, } +#[derive(Clone, Copy, Debug)] +pub enum ClientRoutingMode { + Default, + Reader, + Writer, +} + /// The client state. One of these is created per client. pub struct Client { /// The reads are buffered (8K by default). @@ -73,6 +80,8 @@ pub struct Client { last_server_id: Option, target_pool: ConnectionPool, + + routing_mode: ClientRoutingMode, } /// Client entrypoint. @@ -264,10 +273,50 @@ where trace!("Got StartupMessage"); let parameters = parse_startup(bytes.clone())?; - let database = match parameters.get("database") { + + let database_param = match parameters.get("database") { Some(db) => db, None => return Err(Error::ClientError), }; + let database_name_parts = database_param.split("/").collect::>(); + let (database_name, routing_mode) = match database_name_parts.len() { + 1 => ( + database_name_parts[0].to_string(), + ClientRoutingMode::Default, + ), + 2 => match database_name_parts[1] { + "reader" => { + info!("Client connected in force reader mode"); + ( + database_name_parts[0].to_string(), + ClientRoutingMode::Reader, + ) + } + "writer" => { + info!("Client connected in force writer mode"); + ( + database_name_parts[0].to_string(), + ClientRoutingMode::Writer, + ) + } + _ => { + error_response( + &mut write, + &format!("Invalid database mode {}", database_name_parts[1]), + ) + .await?; + return Err(Error::ClientError); + } + }, + _ => { + error_response( + &mut write, + &format!("Invalid database name {}", database_param), + ) + .await?; + return Err(Error::ClientError); + } + }; let user = match parameters.get("user") { Some(user) => user, @@ -276,7 +325,7 @@ where let admin = ["pgcat", "pgbouncer"] .iter() - .filter(|db| *db == &database) + .filter(|db| *db == &database_name) .count() == 1; @@ -328,14 +377,14 @@ where generate_server_info_for_admin(), ) } else { - let target_pool = match get_pool(database.clone(), user.clone()) { + let target_pool = match get_pool(database_name.clone(), user.clone()) { Some(pool) => pool, None => { error_response( &mut write, &format!( "No pool configured for database: {:?}, user: {:?}", - database, user + database_name, user ), ) .await?; @@ -375,6 +424,7 @@ where buffer: BytesMut::with_capacity(8196), cancel_mode: false, transaction_mode: transaction_mode, + routing_mode: routing_mode, process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, @@ -404,6 +454,7 @@ where buffer: BytesMut::with_capacity(8196), cancel_mode: true, transaction_mode: false, + routing_mode: ClientRoutingMode::Default, process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, @@ -450,7 +501,9 @@ where // The query router determines where the query is going to go, // e.g. primary, replica, which shard. - let mut query_router = QueryRouter::new(self.target_pool.clone()); + let mut query_router = + QueryRouter::new(self.target_pool.clone(), self.routing_mode.clone()); + let mut round_robin = 0; // Our custom protocol loop. diff --git a/src/query_router.rs b/src/query_router.rs index d597b81e..d6484ad1 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -8,6 +8,7 @@ use sqlparser::ast::Statement::{Query, StartTransaction}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; +use crate::client::ClientRoutingMode; use crate::config::Role; use crate::pool::{ConnectionPool, PoolSettings}; use crate::sharding::{Sharder, ShardingFunction}; @@ -56,6 +57,8 @@ pub struct QueryRouter { primary_reads_enabled: bool, pool_settings: PoolSettings, + + client_routing_mode: ClientRoutingMode, } impl QueryRouter { @@ -91,13 +94,14 @@ impl QueryRouter { } /// Create a new instance of the query router. Each client gets its own. - pub fn new(target_pool: ConnectionPool) -> QueryRouter { + pub fn new(target_pool: ConnectionPool, client_routing_mode: ClientRoutingMode) -> QueryRouter { QueryRouter { active_shard: None, active_role: None, query_parser_enabled: target_pool.settings.query_parser_enabled, primary_reads_enabled: target_pool.settings.primary_reads_enabled, pool_settings: target_pool.settings, + client_routing_mode: client_routing_mode, } } @@ -339,7 +343,11 @@ impl QueryRouter { /// Get the current desired server role we should be talking to. pub fn role(&self) -> Option { - self.active_role + match self.client_routing_mode { + ClientRoutingMode::Default => self.active_role, + ClientRoutingMode::Reader => Some(Role::Replica), + ClientRoutingMode::Writer => Some(Role::Primary), + } } /// Get desired shard we should be talking to. @@ -370,7 +378,7 @@ mod test { #[test] fn test_defaults() { QueryRouter::setup(); - let qr = QueryRouter::new(ConnectionPool::default()); + let qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); assert_eq!(qr.role(), None); } @@ -378,7 +386,7 @@ mod test { #[test] fn test_infer_role_replica() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert_eq!(qr.query_parser_enabled(), true); @@ -402,7 +410,7 @@ mod test { #[test] fn test_infer_role_primary() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); let queries = vec![ simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), @@ -421,7 +429,7 @@ mod test { #[test] fn test_infer_role_primary_reads_enabled() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); let query = simple_query("SELECT * FROM items WHERE id = 5"); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); @@ -432,7 +440,7 @@ mod test { #[test] fn test_infer_role_parse_prepared() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); @@ -523,7 +531,7 @@ mod test { #[test] fn test_try_execute_command() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); @@ -600,7 +608,7 @@ mod test { #[test] fn test_enable_query_parser() { QueryRouter::setup(); - let mut qr = QueryRouter::new(ConnectionPool::default()); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default); let query = simple_query("SET SERVER ROLE TO 'auto'"); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); @@ -621,4 +629,33 @@ mod test { assert!(qr.try_execute_command(query) != None); assert!(qr.query_parser_enabled()); } + + #[test] + fn test_client_routing_mode() { + QueryRouter::setup(); + let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Reader); + let query = simple_query("SET SERVER ROLE TO 'auto'"); + assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + + assert!(qr.try_execute_command(query) != None); + assert!(qr.query_parser_enabled()); + assert_eq!(qr.role(), Some(Role::Replica)); + + let query = simple_query("BEGIN"); + assert_eq!(qr.infer_role(query), true); + assert_eq!(qr.role(), Some(Role::Replica)); + + let query = simple_query("INSERT INTO test_table VALUES (1)"); + assert_eq!(qr.infer_role(query), true); + assert_eq!(qr.role(), Some(Role::Replica)); + + let query = simple_query("SELECT * FROM test_table"); + assert_eq!(qr.infer_role(query), true); + assert_eq!(qr.role(), Some(Role::Replica)); + + assert!(qr.query_parser_enabled()); + let query = simple_query("SET SERVER ROLE TO 'default'"); + assert!(qr.try_execute_command(query) != None); + assert!(qr.query_parser_enabled()); + } } diff --git a/tests/ruby/tests.rb b/tests/ruby/tests.rb index c5a55a7e..8fafa2dd 100644 --- a/tests/ruby/tests.rb +++ b/tests/ruby/tests.rb @@ -130,6 +130,25 @@ def poorly_behaved_client end +# Test reader/writer endpoints +def test_reader_writer_endpoints + conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db/reader?application_name=testing_pgcat") + conn.async_exec 'BEGIN' + conn.async_exec 'SELECT 1' + conn.async_exec 'COMMIT' + conn.close + + conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db/writer?application_name=testing_pgcat") + conn.async_exec 'BEGIN' + conn.async_exec 'SELECT 1' + conn.async_exec 'COMMIT' + conn.close + + puts 'Reader/Writer clients ok' +end + +test_reader_writer_endpoints + def test_server_parameters server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") raise StandardError, "Bad server version" if server_conn.server_version == 0 @@ -141,3 +160,5 @@ def test_server_parameters puts 'Server parameters ok' end + +test_server_parameters