diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 1585ebd8..6ffef8ba 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -90,8 +90,8 @@ kill -SIGHUP $(pgrep pgcat) # Reload config again cd tests/ruby sudo gem install bundler bundle install -bundle exec ruby tests.rb -bundle exec rspec *_spec.rb +bundle exec ruby tests.rb || exit 1 +bundle exec rspec *_spec.rb || exit 1 cd ../.. # @@ -99,7 +99,7 @@ cd ../.. # These tests will start and stop the pgcat server so it will need to be restarted after the tests # pip3 install -r tests/python/requirements.txt -python3 tests/python/tests.py +python3 tests/python/tests.py || exit 1 start_pgcat "info" diff --git a/src/client.rs b/src/client.rs index 419448fb..806547d9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -59,6 +59,7 @@ pub struct Client { client_server_map: ClientServerMap, /// Client parameters, e.g. user, client_encoding, etc. + #[allow(dead_code)] parameters: HashMap, /// Statistics @@ -82,6 +83,9 @@ pub struct Client { /// Postgres user for this client (This comes from the user in the connection string) username: String, + /// Application name for this client (defaults to pgcat) + application_name: String, + /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, } @@ -365,6 +369,11 @@ where None => return Err(Error::ClientError), }; + let application_name = match parameters.get("application_name") { + Some(application_name) => application_name, + None => "pgcat", + }; + let admin = ["pgcat", "pgbouncer"] .iter() .filter(|db| *db == &pool_name) @@ -493,6 +502,7 @@ where last_server_id: None, pool_name: pool_name.clone(), username: username.clone(), + application_name: application_name.to_string(), shutdown, connected_to_server: false, }); @@ -526,6 +536,7 @@ where last_server_id: None, pool_name: String::from("undefined"), username: String::from("undefined"), + application_name: String::from("undefined"), shutdown, connected_to_server: false, }); @@ 
-759,13 +770,10 @@ where server.address() ); - // Set application_name if any. // TODO: investigate other parameters and set them too. - if self.parameters.contains_key("application_name") { - server - .set_name(&self.parameters["application_name"]) - .await?; - } + + // Set application_name. + server.set_name(&self.application_name).await?; // Transaction loop. Multiple queries can be issued by the client here. // The connection belongs to the client until the transaction is over, @@ -782,12 +790,7 @@ where Err(err) => { // Client disconnected inside a transaction. // Clean up the server and re-use it. - // This prevents connection thrashing by bad clients. - if server.in_transaction() { - server.query("ROLLBACK").await?; - server.query("DISCARD ALL").await?; - server.set_name("pgcat").await?; - } + server.checkin_cleanup().await?; return Err(err); } @@ -829,16 +832,7 @@ where // Terminate 'X' => { - // Client closing. Rollback and clean up - // connection before releasing into the pool. - // Pgbouncer closes the connection which leads to - // connection thrashing when clients misbehave. - if server.in_transaction() { - server.query("ROLLBACK").await?; - server.query("DISCARD ALL").await?; - server.set_name("pgcat").await?; - } - + server.checkin_cleanup().await?; self.release(); return Ok(()); @@ -942,8 +936,10 @@ where // The server is no longer bound to us, we can't cancel it's queries anymore. debug!("Releasing server back into the pool"); + server.checkin_cleanup().await?; self.stats.server_idle(server.process_id(), address.id); self.connected_to_server = false; + self.release(); self.stats.client_idle(self.process_id, address.id); } diff --git a/src/server.rs b/src/server.rs index 3134a65d..0a53f115 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,8 @@ /// Implementation of the PostgreSQL server (database) protocol. /// Here we are pretending to the a Postgres client. 
use bytes::{Buf, BufMut, BytesMut}; -use log::{debug, error, info, trace}; +use log::{debug, error, info, trace, warn}; +use std::io::Read; use std::time::SystemTime; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ @@ -48,6 +49,9 @@ pub struct Server { /// Is the server broken? We'll remote it from the pool if so. bad: bool, + /// If server connection requires a DISCARD ALL before checkin + needs_cleanup: bool, + /// Mapping of clients and servers used for query cancellation. client_server_map: ClientServerMap, @@ -316,6 +320,7 @@ impl Server { in_transaction: false, data_available: false, bad: false, + needs_cleanup: false, client_server_map: client_server_map, connected_at: chrono::offset::Utc::now().naive_utc(), stats: stats, @@ -440,6 +445,29 @@ impl Server { break; } + // CommandComplete + 'C' => { + let mut command_tag = String::new(); + match message.reader().read_to_string(&mut command_tag) { + Ok(_) => { + // Non-exhaustive list of commands that are likely to change session variables/resources + // which can leak between clients. This is a best effort to block bad clients + // from poisoning a transaction-mode pool by setting inappropriate session variables + match command_tag.as_str() { + "SET\0" | "PREPARE\0" => { + debug!("Server connection marked for clean up"); + self.needs_cleanup = true; + } + _ => (), + } + } + + Err(err) => { + warn!("Encountered an error while parsing CommandTag {}", err); + } + } + } + // DataRow 'D' => { // More data is available after this message, this is not the end of the reply. @@ -553,14 +581,43 @@ impl Server { Ok(()) } + /// Perform any necessary cleanup before putting the server + /// connection back in the pool + pub async fn checkin_cleanup(&mut self) -> Result<(), Error> { + // Client disconnected with an open transaction on the server connection. + // Pgbouncer behavior is to close the server connection but that can cause + // server connection thrashing if clients repeatedly do this. 
+ // Instead, we ROLLBACK that transaction before putting the connection back in the pool + if self.in_transaction() { + self.query("ROLLBACK").await?; + } + + // Client disconnected but it performed session-altering operations such as + // SET statement_timeout to 1 or create a prepared statement. We clear that + // to avoid leaking state between clients. For performance reasons we only + // send `DISCARD ALL` if we think the session is altered instead of just sending + // it before each checkin. + if self.needs_cleanup { + self.query("DISCARD ALL").await?; + self.needs_cleanup = false; + } + + return Ok(()); + } + /// A shorthand for `SET application_name = $1`. - #[allow(dead_code)] pub async fn set_name(&mut self, name: &str) -> Result<(), Error> { if self.application_name != name { self.application_name = name.to_string(); - Ok(self + // We don't want `SET application_name` to mark the server connection + // as needing cleanup + let needs_cleanup_before = self.needs_cleanup; + + let result = Ok(self .query(&format!("SET application_name = '{}'", name)) - .await?) 
+ .await?); + self.needs_cleanup = needs_cleanup_before; + return result; } else { Ok(()) } diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index 30b2bc82..80ac9dab 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -5,7 +5,7 @@ module Helpers module Pgcat - def self.three_shard_setup(pool_name, pool_size) + def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction") user = { "password" => "sharding_user", "pool_size" => pool_size, @@ -22,7 +22,7 @@ def self.three_shard_setup(pool_name, pool_size) pgcat_cfg["pools"] = { "#{pool_name}" => { "default_role" => "any", - "pool_mode" => "transaction", + "pool_mode" => pool_mode, "primary_reads_enabled" => false, "query_parser_enabled" => false, "sharding_function" => "pg_bigint_hash", @@ -46,7 +46,7 @@ def self.three_shard_setup(pool_name, pool_size) end end - def self.single_shard_setup(pool_name, pool_size) + def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction") user = { "password" => "sharding_user", "pool_size" => pool_size, @@ -66,7 +66,7 @@ def self.single_shard_setup(pool_name, pool_size) pgcat_cfg["pools"] = { "#{pool_name}" => { "default_role" => "any", - "pool_mode" => "transaction", + "pool_mode" => pool_mode, "primary_reads_enabled" => false, "query_parser_enabled" => false, "sharding_function" => "pg_bigint_hash", diff --git a/tests/ruby/misc_spec.rb b/tests/ruby/misc_spec.rb index 9aee49af..d5b529a9 100644 --- a/tests/ruby/misc_spec.rb +++ b/tests/ruby/misc_spec.rb @@ -91,7 +91,6 @@ conn.close expect(processes.primary.count_query("ROLLBACK")).to eq(1) - expect(processes.primary.count_query("DISCARD ALL")).to eq(1) end end @@ -106,4 +105,82 @@ admin_conn.close end end + + describe "State clearance" do + context "session mode" do + let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 5, "session") } + + it "Clears state before connection checkin" do + # Both modes of operation should 
not raise + # ERROR: prepared statement "prepared_q" already exists + 15.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("PREPARE prepared_q (int) AS SELECT $1") + conn.close + end + + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + initial_value = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"] + conn.async_exec("SET statement_timeout to 1000") + current_value = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"] + expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq("1s") + conn.close + end + + it "Does not send DISCARD ALL unless necessary" do + 10.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'primary'") + conn.async_exec("SELECT 1") + conn.close + end + + expect(processes.primary.count_query("DISCARD ALL")).to eq(0) + + 10.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'primary'") + conn.async_exec("SELECT 1") + conn.async_exec("SET statement_timeout to 5000") + conn.close + end + + expect(processes.primary.count_query("DISCARD ALL")).to eq(10) + end + end + + context "transaction mode" do + let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 5, "transaction") } + it "Clears state before connection checkin" do + # Both modes of operation should not raise + # ERROR: prepared statement "prepared_q" already exists + 15.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("PREPARE prepared_q (int) AS SELECT $1") + conn.close + end + end + + it "Does not send DISCARD ALL unless necessary" do + 10.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'primary'") + 
conn.async_exec("SELECT 1") + conn.close + end + + expect(processes.primary.count_query("DISCARD ALL")).to eq(0) + + 10.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'primary'") + conn.async_exec("SELECT 1") + conn.async_exec("SET statement_timeout to 5000") + conn.close + end + + expect(processes.primary.count_query("DISCARD ALL")).to eq(10) + end + end + end end