diff --git a/src/client.rs b/src/client.rs index 5b46a5c9..806547d9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -59,6 +59,7 @@ pub struct Client { client_server_map: ClientServerMap, /// Client parameters, e.g. user, client_encoding, etc. + #[allow(dead_code)] parameters: HashMap, /// Statistics @@ -82,6 +83,9 @@ pub struct Client { /// Postgres user for this client (This comes from the user in the connection string) username: String, + /// Application name for this client (defaults to pgcat) + application_name: String, + /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, } @@ -365,6 +369,11 @@ where None => return Err(Error::ClientError), }; + let application_name = match parameters.get("application_name") { + Some(application_name) => application_name, + None => "pgcat", + }; + let admin = ["pgcat", "pgbouncer"] .iter() .filter(|db| *db == &pool_name) @@ -493,6 +502,7 @@ where last_server_id: None, pool_name: pool_name.clone(), username: username.clone(), + application_name: application_name.to_string(), shutdown, connected_to_server: false, }); @@ -526,6 +536,7 @@ where last_server_id: None, pool_name: String::from("undefined"), username: String::from("undefined"), + application_name: String::from("undefined"), shutdown, connected_to_server: false, }); @@ -759,13 +770,10 @@ where server.address() ); - // Set application_name if any. // TODO: investigate other parameters and set them too. - if self.parameters.contains_key("application_name") { - server - .set_name(&self.parameters["application_name"]) - .await?; - } + + // Set application_name. + server.set_name(&self.application_name).await?; // Transaction loop. Multiple queries can be issued by the client here. // The connection belongs to the client until the transaction is over, diff --git a/src/server.rs b/src/server.rs index 208385ef..3e5cf90a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,6 +2,7 @@ /// Here we are pretending to the a Postgres client. use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace}; +use std::io::Read; use std::time::SystemTime; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ @@ -446,15 +447,14 @@ impl Server { // CommandComplete 'C' => { - let full_message = String::from_utf8_lossy(message.as_ref()); - let mut it = full_message.split_ascii_whitespace(); - let command_tag = it.next().unwrap().trim_end_matches(char::from(0)); + let mut command_tag = String::new(); + message.reader().read_to_string(&mut command_tag).unwrap(); - // Non-exhuastive list of commands that are likely to change session variables/resources + // Non-exhaustive list of commands that are likely to change session variables/resources // which can leak between client. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables - match command_tag { - "SET" | "PREPARE" => { + match command_tag.as_str() { + "SET\0" | "PREPARE\0" => { debug!("Server connection marked for clean up"); self.needs_cleanup = true; } @@ -585,13 +585,10 @@ impl Server { self.needs_cleanup = false; } - self.set_name("pgcat").await?; - return Ok(()); } /// A shorthand for `SET application_name = $1`. - #[allow(dead_code)] pub async fn set_name(&mut self, name: &str) -> Result<(), Error> { if self.application_name != name { self.application_name = name.to_string();