diff --git a/.circleci/config.yml b/.circleci/config.yml index c8344911..07224112 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -63,6 +63,9 @@ jobs: - run: name: "Lint" command: "cargo fmt --check" + - run: + name: "Clippy" + command: "cargo clippy --all --all-targets -- -Dwarnings" - run: name: "Tests" command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 10d4924a..e0d5d160 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Thank you for contributing! Just a few tips here: -1. `cargo fmt` your code before opening up a PR +1. `cargo fmt` and `cargo clippy` your code before opening up a PR 2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`. 3. Performance is important, make sure there are no regressions in your branch vs. `main`. diff --git a/src/admin.rs b/src/admin.rs index da925292..f1b0c63f 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -283,7 +283,7 @@ where { let mut res = BytesMut::new(); - let detail_msg = vec![ + let detail_msg = [ "", "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION", // "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS @@ -301,7 +301,6 @@ where // "KILL ", // "SUSPEND", "SHUTDOWN", - // "WAIT_CLOSE []", // missing ]; res.put(notify("Console usage", detail_msg.join("\n\t"))); @@ -802,7 +801,7 @@ where T: tokio::io::AsyncWrite + std::marker::Unpin, { let parts: Vec<&str> = match tokens.len() == 2 { - true => tokens[1].split(",").map(|part| part.trim()).collect(), + true => tokens[1].split(',').map(|part| part.trim()).collect(), false => Vec::new(), }; @@ -865,7 +864,7 @@ where T: tokio::io::AsyncWrite + std::marker::Unpin, { let parts: Vec<&str> = match tokens.len() == 2 { - true => tokens[1].split(",").map(|part| part.trim()).collect(), + true => tokens[1].split(',').map(|part| part.trim()).collect(), false => Vec::new(), }; diff --git a/src/client.rs b/src/client.rs index 2ec5b6a7..98a0669c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -131,7 +131,7 @@ pub async fn client_entrypoint( // Client requested a TLS connection. Ok((ClientConnectionType::Tls, _)) => { // TLS settings are configured, will setup TLS now. - if tls_certificate != None { + if tls_certificate.is_some() { debug!("Accepting TLS request"); let mut yes = BytesMut::new(); @@ -448,7 +448,7 @@ where None => "pgcat", }; - let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name); + let client_identifier = ClientIdentifier::new(application_name, username, pool_name); let admin = ["pgcat", "pgbouncer"] .iter() @@ -795,7 +795,7 @@ where let mut will_prepare = false; let client_identifier = ClientIdentifier::new( - &self.server_parameters.get_application_name(), + self.server_parameters.get_application_name(), &self.username, &self.pool_name, ); @@ -982,15 +982,11 @@ where } // Check on plugin results. - match plugin_output { - Some(PluginOutput::Deny(error)) => { - self.buffer.clear(); - error_response(&mut self.write, &error).await?; - plugin_output = None; - continue; - } - - _ => (), + if let Some(PluginOutput::Deny(error)) = plugin_output { + self.buffer.clear(); + error_response(&mut self.write, &error).await?; + plugin_output = None; + continue; }; // Check if the pool is paused and wait until it's resumed. @@ -1267,7 +1263,7 @@ where // Safe to unwrap because we know this message has a certain length and has the code // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.get(0).unwrap() as char; + let code = *message.first().unwrap() as char; trace!("Message: {}", code); @@ -1325,7 +1321,7 @@ where self.stats.transaction(); server .stats() - .transaction(&self.server_parameters.get_application_name()); + .transaction(self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1400,13 +1396,10 @@ where let close: Close = (&message).try_into()?; if close.is_prepared_statement() && !close.anonymous() { - match self.prepared_statements.get(&close.name) { - Some(parse) => { - server.will_close(&parse.generated_name); - } - + if let Some(parse) = self.prepared_statements.get(&close.name) { + server.will_close(&parse.generated_name); + } else { // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. - None => (), }; } } @@ -1445,7 +1438,7 @@ where self.buffer.put(&message[..]); - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' && !prepared_statements_enabled { @@ -1477,7 +1470,7 @@ where self.stats.transaction(); server .stats() - .transaction(&self.server_parameters.get_application_name()); + .transaction(self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1739,7 +1732,7 @@ where client_stats.query(); server.stats().query( Instant::now().duration_since(query_start).as_millis() as u64, - &self.server_parameters.get_application_name(), + self.server_parameters.get_application_name(), ); Ok(()) diff --git a/src/cmd_args.rs b/src/cmd_args.rs index 3989d670..1abb7ed9 100644 --- a/src/cmd_args.rs +++ b/src/cmd_args.rs @@ -25,7 +25,7 @@ pub struct Args { } pub fn parse() -> Args { - return Args::parse(); + Args::parse() } #[derive(ValueEnum, Clone, Debug)] diff --git a/src/config.rs b/src/config.rs index 90b4beb3..f91e488e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -236,18 +236,14 @@ impl Default for User { impl User { fn validate(&self) -> Result<(), Error> { - match self.min_pool_size { - Some(min_pool_size) => { - if min_pool_size > self.pool_size { - error!( - "min_pool_size of {} cannot be larger than pool_size of {}", - min_pool_size, self.pool_size - ); - return Err(Error::BadConfig); - } + if let Some(min_pool_size) = self.min_pool_size { + if min_pool_size > self.pool_size { + error!( + "min_pool_size of {} cannot be larger than pool_size of {}", + min_pool_size, self.pool_size + ); + return Err(Error::BadConfig); } - - None => (), }; Ok(()) @@ -677,9 +673,9 @@ impl Pool { Some(key) => { // No quotes in the key so we don't have to compare quoted // to unquoted idents. - let key = key.replace("\"", ""); + let key = key.replace('\"', ""); - if key.split(".").count() != 2 { + if key.split('.').count() != 2 { error!( "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`", key, key @@ -692,17 +688,14 @@ impl Pool { None => None, }; - match self.default_shard { - DefaultShard::Shard(shard_number) => { - if shard_number >= self.shards.len() { - error!("Invalid shard {:?}", shard_number); - return Err(Error::BadConfig); - } + if let DefaultShard::Shard(shard_number) = self.default_shard { + if shard_number >= self.shards.len() { + error!("Invalid shard {:?}", shard_number); + return Err(Error::BadConfig); } - _ => (), } - for (_, user) in &self.users { + for user in self.users.values() { user.validate()?; } @@ -777,8 +770,8 @@ impl<'de> serde::Deserialize<'de> for DefaultShard { D: Deserializer<'de>, { let s = String::deserialize(deserializer)?; - if s.starts_with("shard_") { - let shard = s[6..].parse::().map_err(serde::de::Error::custom)?; + if let Some(s) = s.strip_prefix("shard_") { + let shard = s.parse::().map_err(serde::de::Error::custom)?; return Ok(DefaultShard::Shard(shard)); } @@ -874,7 +867,7 @@ pub trait Plugin { impl std::fmt::Display for Plugins { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn is_enabled(arg: Option<&T>) -> bool { - if let Some(ref arg) = arg { + if let Some(arg) = arg { arg.is_enabled() } else { false @@ -955,6 +948,7 @@ pub struct Query { } impl Query { + #[allow(clippy::needless_range_loop)] pub fn substitute(&mut self, db: &str, user: &str) { for col in self.result.iter_mut() { for i in 0..col.len() { @@ -1079,8 +1073,8 @@ impl From<&Config> for std::collections::HashMap { ( format!("pools.{:?}.users", pool_name), pool.users - .iter() - .map(|(_username, user)| &user.username) + .values() + .map(|user| &user.username) .cloned() .collect::>() .join(", "), @@ -1165,13 +1159,9 @@ impl Config { Some(tls_certificate) => { info!("TLS certificate: {}", tls_certificate); - match self.general.tls_private_key.clone() { - Some(tls_private_key) => { - info!("TLS private key: {}", tls_private_key); - info!("TLS support is enabled"); - } - - None => (), + if let Some(tls_private_key) = self.general.tls_private_key.clone() { + info!("TLS private key: {}", tls_private_key); + info!("TLS support is enabled"); } } @@ -1206,8 +1196,8 @@ impl Config { pool_name, pool_config .users - .iter() - .map(|(_, user_cfg)| user_cfg.pool_size) + .values() + .map(|user_cfg| user_cfg.pool_size) .sum::() .to_string() ); @@ -1377,34 +1367,31 @@ impl Config { } // Validate TLS! - match self.general.tls_certificate.clone() { - Some(tls_certificate) => { - match load_certs(Path::new(&tls_certificate)) { - Ok(_) => { - // Cert is okay, but what about the private key? - match self.general.tls_private_key.clone() { - Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) { - Ok(_) => (), - Err(err) => { - error!("tls_private_key is incorrectly configured: {:?}", err); - return Err(Error::BadConfig); - } - }, - - None => { - error!("tls_certificate is set, but the tls_private_key is not"); + if let Some(tls_certificate) = self.general.tls_certificate.clone() { + match load_certs(Path::new(&tls_certificate)) { + Ok(_) => { + // Cert is okay, but what about the private key? + match self.general.tls_private_key.clone() { + Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) { + Ok(_) => (), + Err(err) => { + error!("tls_private_key is incorrectly configured: {:?}", err); return Err(Error::BadConfig); } - }; - } + }, - Err(err) => { - error!("tls_certificate is incorrectly configured: {:?}", err); - return Err(Error::BadConfig); - } + None => { + error!("tls_certificate is set, but the tls_private_key is not"); + return Err(Error::BadConfig); + } + }; + } + + Err(err) => { + error!("tls_certificate is incorrectly configured: {:?}", err); + return Err(Error::BadConfig); } } - None => (), }; for pool in self.pools.values_mut() { diff --git a/src/messages.rs b/src/messages.rs index 07fe9317..86036a92 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -163,12 +163,10 @@ where match stream.write_all(&startup).await { Ok(_) => Ok(()), - Err(err) => { - return Err(Error::SocketError(format!( - "Error writing startup to server socket - Error: {:?}", - err - ))) - } + Err(err) => Err(Error::SocketError(format!( + "Error writing startup to server socket - Error: {:?}", + err + ))), } } @@ -244,8 +242,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { let mut md5 = Md5::new(); // First pass - md5.update(&password.as_bytes()); - md5.update(&user.as_bytes()); + md5.update(password.as_bytes()); + md5.update(user.as_bytes()); let output = md5.finalize_reset(); @@ -281,7 +279,7 @@ where { let password = md5_hash_password(user, password, salt); - let mut message = BytesMut::with_capacity(password.len() as usize + 5); + let mut message = BytesMut::with_capacity(password.len() + 5); message.put_u8(b'p'); message.put_i32(password.len() as i32 + 4); @@ -295,7 +293,7 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin, { let password = md5_hash_second_pass(hash, salt); - let mut message = BytesMut::with_capacity(password.len() as usize + 5); + let mut message = BytesMut::with_capacity(password.len() + 5); message.put_u8(b'p'); message.put_i32(password.len() as i32 + 4); @@ -516,7 +514,7 @@ pub fn data_row_nullable(row: &Vec>) -> BytesMut { data_row.put_i32(column.len() as i32); data_row.put_slice(column); } else { - data_row.put_i32(-1 as i32); + data_row.put_i32(-1_i32); } } @@ -571,12 +569,10 @@ where { match stream.write_all(&buf).await { Ok(_) => Ok(()), - Err(err) => { - return Err(Error::SocketError(format!( - "Error writing to socket - Error: {:?}", - err - ))) - } + Err(err) => Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))), } } @@ -587,12 +583,10 @@ where { match stream.write_all(buf).await { Ok(_) => Ok(()), - Err(err) => { - return Err(Error::SocketError(format!( - "Error writing to socket - Error: {:?}", - err - ))) - } + Err(err) => Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))), } } @@ -603,19 +597,15 @@ where match stream.write_all(buf).await { Ok(_) => match stream.flush().await { Ok(_) => Ok(()), - Err(err) => { - return Err(Error::SocketError(format!( - "Error flushing socket - Error: {:?}", - err - ))) - } - }, - Err(err) => { - return Err(Error::SocketError(format!( - "Error writing to socket - Error: {:?}", + Err(err) => Err(Error::SocketError(format!( + "Error flushing socket - Error: {:?}", err - ))) - } + ))), + }, + Err(err) => Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))), } } @@ -730,7 +720,7 @@ impl BytesMutReader for Cursor<&BytesMut> { let mut buf = vec![]; match self.read_until(b'\0', &mut buf) { Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), - Err(err) => return Err(Error::ParseBytesError(err.to_string())), + Err(err) => Err(Error::ParseBytesError(err.to_string())), } } } @@ -746,7 +736,7 @@ impl BytesMutReader for BytesMut { let string_bytes = self.split_to(index + 1); Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string()) } - None => return Err(Error::ParseBytesError("Could not read string".to_string())), + None => Err(Error::ParseBytesError("Could not read string".to_string())), } } } @@ -1311,38 +1301,38 @@ mod tests { fn parse_fields() { let mut complete_msg = vec![]; let severity = "FATAL"; - complete_msg.extend(field('S', &severity)); - complete_msg.extend(field('V', &severity)); + complete_msg.extend(field('S', severity)); + complete_msg.extend(field('V', severity)); let error_code = "29P02"; - complete_msg.extend(field('C', &error_code)); + complete_msg.extend(field('C', error_code)); let message = "password authentication failed for user \"wrong_user\""; - complete_msg.extend(field('M', &message)); + complete_msg.extend(field('M', message)); let detail_msg = "super detailed message"; - complete_msg.extend(field('D', &detail_msg)); + complete_msg.extend(field('D', detail_msg)); let hint_msg = "hint detail here"; - complete_msg.extend(field('H', &hint_msg)); + complete_msg.extend(field('H', hint_msg)); complete_msg.extend(field('P', "123")); complete_msg.extend(field('p', "234")); let internal_query = "SELECT * from foo;"; - complete_msg.extend(field('q', &internal_query)); + complete_msg.extend(field('q', internal_query)); let where_msg = "where goes here"; - complete_msg.extend(field('W', &where_msg)); + complete_msg.extend(field('W', where_msg)); let schema_msg = "schema_name"; - complete_msg.extend(field('s', &schema_msg)); + complete_msg.extend(field('s', schema_msg)); let table_msg = "table_name"; - complete_msg.extend(field('t', &table_msg)); + complete_msg.extend(field('t', table_msg)); let column_msg = "column_name"; - complete_msg.extend(field('c', &column_msg)); + complete_msg.extend(field('c', column_msg)); let data_type_msg = "type_name"; - complete_msg.extend(field('d', &data_type_msg)); + complete_msg.extend(field('d', data_type_msg)); let constraint_msg = "constraint_name"; - complete_msg.extend(field('n', &constraint_msg)); + complete_msg.extend(field('n', constraint_msg)); let file_msg = "pgcat.c"; - complete_msg.extend(field('F', &file_msg)); + complete_msg.extend(field('F', file_msg)); complete_msg.extend(field('L', "335")); let routine_msg = "my_failing_routine"; - complete_msg.extend(field('R', &routine_msg)); + complete_msg.extend(field('R', routine_msg)); tracing_subscriber::fmt() .with_max_level(tracing::Level::INFO) @@ -1378,11 +1368,11 @@ mod tests { ); let mut only_mandatory_msg = vec![]; - only_mandatory_msg.extend(field('S', &severity)); - only_mandatory_msg.extend(field('V', &severity)); - only_mandatory_msg.extend(field('C', &error_code)); - only_mandatory_msg.extend(field('M', &message)); - only_mandatory_msg.extend(field('D', &detail_msg)); + only_mandatory_msg.extend(field('S', severity)); + only_mandatory_msg.extend(field('V', severity)); + only_mandatory_msg.extend(field('C', error_code)); + only_mandatory_msg.extend(field('M', message)); + only_mandatory_msg.extend(field('D', detail_msg)); let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap(); info!("only mandatory fields: {}", &err_fields); diff --git a/src/mirrors.rs b/src/mirrors.rs index 90bcd355..f704a8cd 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -137,18 +137,18 @@ impl MirroringManager { bytes_rx, disconnect_rx: exit_rx, }; - exit_senders.push(exit_tx.clone()); - byte_senders.push(bytes_tx.clone()); + exit_senders.push(exit_tx); + byte_senders.push(bytes_tx); client.start(); }); Self { - byte_senders: byte_senders, + byte_senders, disconnect_senders: exit_senders, } } - pub fn send(self: &mut Self, bytes: &BytesMut) { + pub fn send(&mut self, bytes: &BytesMut) { // We want to avoid performing an allocation if we won't be able to send the message // There is a possibility of a race here where we check the capacity and then the channel is // closed or the capacity is reduced to 0, but mirroring is best effort anyway @@ -170,7 +170,7 @@ impl MirroringManager { }); } - pub fn disconnect(self: &mut Self) { + pub fn disconnect(&mut self) { self.disconnect_senders .iter_mut() .for_each(|sender| match sender.try_send(()) { diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs index 166294bc..d13ab073 100644 --- a/src/plugins/intercept.rs +++ b/src/plugins/intercept.rs @@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> { .map(|s| { let s = s.as_str().to_string(); - if s == "" { + if s.is_empty() { None } else { Some(s) diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index 5ef6009a..f1076d06 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -33,6 +33,7 @@ pub enum PluginOutput { #[async_trait] pub trait Plugin { // Run before the query is sent to the server. + #[allow(clippy::ptr_arg)] async fn run( &mut self, query_router: &QueryRouter, diff --git a/src/plugins/prewarmer.rs b/src/plugins/prewarmer.rs index a09bbe9d..cd93db9a 100644 --- a/src/plugins/prewarmer.rs +++ b/src/plugins/prewarmer.rs @@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> { self.server.address(), query ); - self.server.query(&query).await?; + self.server.query(query).await?; } Ok(()) diff --git a/src/plugins/table_access.rs b/src/plugins/table_access.rs index 79c1260e..b8153b5a 100644 --- a/src/plugins/table_access.rs +++ b/src/plugins/table_access.rs @@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> { visit_relations(ast, |relation| { let relation = relation.to_string(); - let parts = relation.split(".").collect::>(); + let parts = relation.split('.').collect::>(); let table_name = parts.last().unwrap(); if self.tables.contains(&table_name.to_string()) { diff --git a/src/pool.rs b/src/pool.rs index 02dab273..736dc1ad 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -241,20 +241,17 @@ impl ConnectionPool { let old_pool_ref = get_pool(pool_name, &user.username); let identifier = PoolIdentifier::new(pool_name, &user.username); - match old_pool_ref { - Some(pool) => { - // If the pool hasn't changed, get existing reference and insert it into the new_pools. - // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). - if pool.config_hash == new_pool_hash_value { - info!( - "[pool: {}][user: {}] has not changed", - pool_name, user.username - ); - new_pools.insert(identifier.clone(), pool.clone()); - continue; - } + if let Some(pool) = old_pool_ref { + // If the pool hasn't changed, get existing reference and insert it into the new_pools. + // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). + if pool.config_hash == new_pool_hash_value { + info!( + "[pool: {}][user: {}] has not changed", + pool_name, user.username + ); + new_pools.insert(identifier.clone(), pool.clone()); + continue; } - None => (), } info!( @@ -399,7 +396,7 @@ impl ConnectionPool { }, }; - let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE] + let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE] .iter() .min() .unwrap(); @@ -489,7 +486,7 @@ impl ConnectionPool { .clone() .map(|regex| Regex::new(regex.as_str()).unwrap()), regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), - default_shard: pool_config.default_shard.clone(), + default_shard: pool_config.default_shard, auth_query: pool_config.auth_query.clone(), auth_query_user: pool_config.auth_query_user.clone(), auth_query_password: pool_config.auth_query_password.clone(), @@ -678,7 +675,7 @@ impl ConnectionPool { let mut force_healthcheck = false; if self.is_banned(address) { - if self.try_unban(&address).await { + if self.try_unban(address).await { force_healthcheck = true; } else { debug!("Address {:?} is banned", address); @@ -806,8 +803,8 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, BanReason::FailedHealthCheck, Some(client_info)); - return false; + self.ban(address, BanReason::FailedHealthCheck, Some(client_info)); + false } /// Ban an address (i.e. replica). It no longer will serve @@ -931,10 +928,10 @@ impl ConnectionPool { let guard = self.banlist.read(); for banlist in guard.iter() { for (address, (reason, timestamp)) in banlist.iter() { - bans.push((address.clone(), (reason.clone(), timestamp.clone()))); + bans.push((address.clone(), (reason.clone(), *timestamp))); } } - return bans; + bans } /// Get the address from the host url @@ -992,7 +989,7 @@ impl ConnectionPool { } let busy = provisioned - idle; debug!("{:?} has {:?} busy connections", address, busy); - return busy; + busy } fn valid_shard_id(&self, shard: Option) -> bool { @@ -1031,6 +1028,7 @@ pub struct ServerPool { } impl ServerPool { + #[allow(clippy::too_many_arguments)] pub fn new( address: Address, user: User, @@ -1043,7 +1041,7 @@ impl ServerPool { ) -> ServerPool { ServerPool { address, - user: user.clone(), + user, database: database.to_string(), client_server_map, auth_hash, diff --git a/src/query_router.rs b/src/query_router.rs index 189f2dcc..8b451dd3 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -91,7 +91,7 @@ impl QueryRouter { /// One-time initialization of regexes /// that parse our custom SQL protocol. pub fn setup() -> bool { - let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) { + let set = match RegexSet::new(CUSTOM_SQL_REGEXES) { Ok(rgx) => rgx, Err(err) => { error!("QueryRouter::setup Could not compile regex set: {:?}", err); @@ -132,7 +132,7 @@ impl QueryRouter { self.pool_settings = pool_settings; } - pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings { + pub fn pool_settings(&self) -> &PoolSettings { &self.pool_settings } @@ -148,7 +148,7 @@ impl QueryRouter { // Check for any sharding regex matches in any queries if comment_shard_routing_enabled { - match code as char { + match code { // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement 'P' | 'Q' => { // Check only the first block of bytes configured by the pool settings @@ -344,16 +344,13 @@ impl QueryRouter { let code = message_cursor.get_u8() as char; let len = message_cursor.get_i32() as usize; - match self.pool_settings.query_parser_max_length { - Some(max_length) => { - if len > max_length { - return Err(Error::QueryRouterParserError(format!( - "Query too long for parser: {} > {}", - len, max_length - ))); - } + if let Some(max_length) = self.pool_settings.query_parser_max_length { + if len > max_length { + return Err(Error::QueryRouterParserError(format!( + "Query too long for parser: {} > {}", + len, max_length + ))); } - None => (), }; let query = match code { @@ -467,22 +464,18 @@ impl QueryRouter { inferred_shard: Option, prev_inferred_shard: &mut Option, ) -> Result<(), Error> { - match inferred_shard { - Some(shard) => { - if let Some(prev_shard) = *prev_inferred_shard { - if prev_shard != shard { - debug!("Found more than one shard in the query, not supported yet"); - return Err(Error::QueryRouterParserError( - "multiple shards in query".into(), - )); - } + if let Some(shard) = inferred_shard { + if let Some(prev_shard) = *prev_inferred_shard { + if prev_shard != shard { + debug!("Found more than one shard in the query, not supported yet"); + return Err(Error::QueryRouterParserError( + "multiple shards in query".into(), + )); } - *prev_inferred_shard = Some(shard); - self.active_shard = Some(shard); - debug!("Automatically using shard: {:?}", self.active_shard); } - - None => (), + *prev_inferred_shard = Some(shard); + self.active_shard = Some(shard); + debug!("Automatically using shard: {:?}", self.active_shard); }; Ok(()) } @@ -513,7 +506,7 @@ impl QueryRouter { assert!(after_columns.is_empty()); Self::process_table(table_name, &mut table_names); - Self::process_query(&*source, &mut exprs, &mut table_names, &Some(columns)); + Self::process_query(source, &mut exprs, &mut table_names, &Some(columns)); } Delete { tables, @@ -529,7 +522,7 @@ impl QueryRouter { // Multi tables delete are not supported in postgres. assert!(tables.is_empty()); - Self::process_tables_with_join(&from, &mut exprs, &mut table_names); + Self::process_tables_with_join(from, &mut exprs, &mut table_names); if let Some(using_tbl_with_join) = using { Self::process_tables_with_join( using_tbl_with_join, @@ -569,7 +562,7 @@ impl QueryRouter { ) { match &*query.body { SetExpr::Query(query) => { - Self::process_query(&*query, exprs, table_names, columns); + Self::process_query(query, exprs, table_names, columns); } // SELECT * FROM ... @@ -611,7 +604,7 @@ impl QueryRouter { } fn process_tables_with_join( - tables: &Vec, + tables: &[TableWithJoins], exprs: &mut Vec, table_names: &mut Vec>, ) { @@ -625,37 +618,21 @@ impl QueryRouter { exprs: &mut Vec, table_names: &mut Vec>, ) { - match &table.relation { - TableFactor::Table { name, .. } => { - Self::process_table(name, table_names); - } - - _ => (), + if let TableFactor::Table { name, .. } = &table.relation { + Self::process_table(name, table_names); }; // Get table names from all the joins. for join in table.joins.iter() { - match &join.relation { - TableFactor::Table { name, .. } => { - Self::process_table(name, table_names); - } - - _ => (), + if let TableFactor::Table { name, .. } = &join.relation { + Self::process_table(name, table_names); }; // We can filter results based on join conditions, e.g. // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; - match &join.join_operator { - JoinOperator::Inner(inner_join) => match &inner_join { - JoinConstraint::On(expr) => { - // Parse the selection criteria later. - exprs.push(expr.clone()); - } - - _ => (), - }, - - _ => (), + if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator { + // Parse the selection criteria later. + exprs.push(expr.clone()); }; } } @@ -814,7 +791,7 @@ impl QueryRouter { .automatic_sharding_key .as_ref() .unwrap() - .split(".") + .split('.') .map(|ident| Ident::new(ident.to_lowercase())) .collect::>(); @@ -822,12 +799,12 @@ impl QueryRouter { assert_eq!(sharding_key.len(), 2); for a in assignments { - if sharding_key[0].value == "*" { - if sharding_key[1].value == a.id.last().unwrap().value.to_lowercase() { - return Err(Error::QueryRouterParserError( - "Sharding key cannot be updated.".into(), - )); - } + if sharding_key[0].value == "*" + && sharding_key[1].value == a.id.last().unwrap().value.to_lowercase() + { + return Err(Error::QueryRouterParserError( + "Sharding key cannot be updated.".into(), + )); } } Ok(()) @@ -844,7 +821,7 @@ impl QueryRouter { .automatic_sharding_key .as_ref() .unwrap() - .split(".") + .split('.') .map(|ident| Ident::new(ident.to_lowercase())) .collect::>(); @@ -861,7 +838,7 @@ impl QueryRouter { Expr::Identifier(ident) => { // Only if we're dealing with only one table // and there is no ambiguity - if &ident.value.to_lowercase() == &sharding_key[1].value { + if ident.value.to_lowercase() == sharding_key[1].value { // Sharding key is unique enough, don't worry about // table names. if &sharding_key[0].value == "*" { @@ -874,13 +851,13 @@ impl QueryRouter { // SELECT * FROM t WHERE sharding_key = 5 // Make sure the table name from the sharding key matches // the table name from the query. - found = &sharding_key[0].value == &table[0].value.to_lowercase(); + found = sharding_key[0].value == table[0].value.to_lowercase(); } else if table.len() == 2 { // Table name is fully qualified with the schema: e.g. // SELECT * FROM public.t WHERE sharding_key = 5 // Ignore the schema (TODO: at some point, we want schema support) // and use the table name only. - found = &sharding_key[0].value == &table[1].value.to_lowercase(); + found = sharding_key[0].value == table[1].value.to_lowercase(); } else { debug!("Got table name with more than two idents, which is not possible"); } @@ -893,8 +870,8 @@ impl QueryRouter { // it will exist or Postgres will throw an error. if idents.len() == 2 { found = (&sharding_key[0].value == "*" - || &sharding_key[0].value == &idents[0].value.to_lowercase()) - && &sharding_key[1].value == &idents[1].value.to_lowercase(); + || sharding_key[0].value == idents[0].value.to_lowercase()) + && sharding_key[1].value == idents[1].value.to_lowercase(); } // TODO: key can have schema as well, e.g. public.data.id (len == 3) } @@ -926,7 +903,7 @@ impl QueryRouter { } Expr::Value(Value::Placeholder(placeholder)) => { - match placeholder.replace("$", "").parse::() { + match placeholder.replace('$', "").parse::() { Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)), Err(_) => { debug!( @@ -1020,16 +997,16 @@ impl QueryRouter { db: &self.pool_settings.db, }; - let _ = query_logger.run(&self, ast).await; + let _ = query_logger.run(self, ast).await; } if let Some(ref intercept) = plugins.intercept { let mut intercept = Intercept { enabled: intercept.enabled, - config: &intercept, + config: intercept, }; - let result = intercept.run(&self, ast).await; + let result = intercept.run(self, ast).await; if let Ok(PluginOutput::Intercept(output)) = result { return Ok(PluginOutput::Intercept(output)); @@ -1042,7 +1019,7 @@ impl QueryRouter { tables: &table_access.tables, }; - let result = table_access.run(&self, ast).await; + let result = table_access.run(self, ast).await; if let Ok(PluginOutput::Deny(error)) = result { return Ok(PluginOutput::Deny(error)); @@ -1078,7 +1055,7 @@ impl QueryRouter { /// Should we attempt to parse queries? pub fn query_parser_enabled(&self) -> bool { - let enabled = match self.query_parser_enabled { + match self.query_parser_enabled { None => { debug!( "Using pool settings, query_parser_enabled: {}", @@ -1094,9 +1071,7 @@ impl QueryRouter { ); value } - }; - - enabled + } } pub fn primary_reads_enabled(&self) -> bool { @@ -1107,6 +1082,12 @@ impl QueryRouter { } } +impl Default for QueryRouter { + fn default() -> Self { + Self::new() + } +} + #[cfg(test)] mod test { use super::*; @@ -1128,10 +1109,14 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); qr.pool_settings.query_parser_read_write_splitting = true; - assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert!(qr + .try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) + .is_some()); assert!(qr.query_parser_enabled()); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO off")) + .is_some()); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -1173,7 +1158,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO on")) + .is_some()); assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), None); @@ -1186,7 +1173,9 @@ mod test { qr.pool_settings.query_parser_read_write_splitting = true; qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO off")) + .is_some()); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -1356,9 +1345,11 @@ mod test { qr.pool_settings.query_parser_read_write_splitting = true; let query = simple_query("SET SERVER ROLE TO 'auto'"); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO off")) + .is_some()); - assert!(qr.try_execute_command(&query) != None); + assert!(qr.try_execute_command(&query).is_some()); assert!(qr.query_parser_enabled()); assert_eq!(qr.role(), None); @@ -1372,7 +1363,7 @@ mod test { assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(&query) != None); + assert!(qr.try_execute_command(&query).is_some()); assert!(!qr.query_parser_enabled()); } @@ -1420,11 +1411,11 @@ mod test { assert!(!qr.primary_reads_enabled()); let q1 = simple_query("SET SERVER ROLE TO 'primary'"); - assert!(qr.try_execute_command(&q1) != None); + assert!(qr.try_execute_command(&q1).is_some()); assert_eq!(qr.active_role.unwrap(), Role::Primary); let q2 = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(&q2) != None); + assert!(qr.try_execute_command(&q2).is_some()); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); } @@ -1485,29 +1476,29 @@ mod test { }; let mut qr = QueryRouter::new(); - qr.update_pool_settings(pool_settings.clone()); + qr.update_pool_settings(pool_settings); // Shard should start out unset assert_eq!(qr.active_shard, None); // Don't panic when short query eg. ; is sent let q0 = simple_query(";"); - assert!(qr.try_execute_command(&q0) == None); + assert!(qr.try_execute_command(&q0).is_none()); assert_eq!(qr.active_shard, None); // Make sure setting it works let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;"); - assert!(qr.try_execute_command(&q1) == None); + assert!(qr.try_execute_command(&q1).is_none()); assert_eq!(qr.active_shard, Some(1)); // And make sure changing it works let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;"); - assert!(qr.try_execute_command(&q2) == None); + assert!(qr.try_execute_command(&q2).is_none()); assert_eq!(qr.active_shard, Some(0)); // Validate setting by shard with expected shard copied from sharding.rs tests let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;"); - assert!(qr.try_execute_command(&q2) == None); + assert!(qr.try_execute_command(&q2).is_none()); assert_eq!(qr.active_shard, Some(2)); } @@ -1863,10 +1854,11 @@ mod test { }; QueryRouter::setup(); - let mut pool_settings = PoolSettings::default(); - pool_settings.query_parser_enabled = true; - pool_settings.plugins = Some(plugins); - + let pool_settings = PoolSettings { + query_parser_enabled: true, + plugins: Some(plugins), + ..Default::default() + }; let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings); diff --git a/src/scram.rs b/src/scram.rs index 3e5d8470..111dd5e1 100644 --- a/src/scram.rs +++ b/src/scram.rs @@ -79,12 +79,12 @@ impl ScramSha256 { let server_message = Message::parse(message)?; if !server_message.nonce.starts_with(&self.nonce) { - return Err(Error::ProtocolSyncError(format!("SCRAM"))); + return Err(Error::ProtocolSyncError("SCRAM".to_string())); } let salt = match general_purpose::STANDARD.decode(&server_message.salt) { Ok(salt) => salt, - Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), + Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), }; let salted_password = Self::hi( @@ -166,9 +166,9 @@ impl ScramSha256 { pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { let final_message = FinalMessage::parse(message)?; - let verifier = match general_purpose::STANDARD.decode(&final_message.value) { + let verifier = match general_purpose::STANDARD.decode(final_message.value) { Ok(verifier) => verifier, - Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), + Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), }; let mut hmac = match Hmac::::new_from_slice(&self.salted_password) { @@ -230,14 +230,14 @@ impl Message { .collect::>(); if parts.len() != 3 { - return Err(Error::ProtocolSyncError(format!("SCRAM"))); + return Err(Error::ProtocolSyncError("SCRAM".to_string())); } let nonce = str::replace(&parts[0], "r=", ""); let salt = str::replace(&parts[1], "s=", ""); let iterations = match str::replace(&parts[2], "i=", "").parse::() { Ok(iterations) => iterations, - Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), + Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), }; Ok(Message { @@ -257,7 +257,7 @@ impl FinalMessage { /// Parse the server final validation message. pub fn parse(message: &BytesMut) -> Result { if !message.starts_with(b"v=") || message.len() < 4 { - return Err(Error::ProtocolSyncError(format!("SCRAM"))); + return Err(Error::ProtocolSyncError("SCRAM".to_string())); } Ok(FinalMessage { diff --git a/src/server.rs b/src/server.rs index 70c8270d..5ccafeeb 100644 --- a/src/server.rs +++ b/src/server.rs @@ -197,12 +197,8 @@ impl ServerParameters { key = "DateStyle".to_string(); }; - if TRACKED_PARAMETERS.contains(&key) { + if TRACKED_PARAMETERS.contains(&key) || startup { self.parameters.insert(key, value); - } else { - if startup { - self.parameters.insert(key, value); - } } } @@ -332,6 +328,7 @@ pub struct Server { impl Server { /// Pretend to be the Postgres client and connect to the server given host, port and credentials. /// Perform the authentication and return the server in a ready for query state. + #[allow(clippy::too_many_arguments)] pub async fn startup( address: &Address, user: &User, @@ -440,10 +437,7 @@ impl Server { // Something else? m => { - return Err(Error::SocketError(format!( - "Unknown message: {}", - m as char - ))); + return Err(Error::SocketError(format!("Unknown message: {}", { m }))); } } } else { @@ -461,6 +455,8 @@ impl Server { None => &user.username, }; + #[allow(clippy::match_as_ref)] + #[allow(clippy::manual_map)] let password = match user.server_password { Some(ref server_password) => Some(server_password), None => match user.password { @@ -473,14 +469,11 @@ impl Server { let mut process_id: i32 = 0; let mut secret_key: i32 = 0; - let server_identifier = ServerIdentifier::new(username, &database); + let server_identifier = ServerIdentifier::new(username, database); // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete. - let mut scram: Option = match password { - Some(password) => Some(ScramSha256::new(password)), - None => None, - }; + let mut scram: Option = password.map(|password| ScramSha256::new(password)); let mut server_parameters = ServerParameters::new(); @@ -882,7 +875,7 @@ impl Server { self.mirror_send(messages); self.stats().data_sent(messages.len()); - match write_all_flush(&mut self.stream, &messages).await { + match write_all_flush(&mut self.stream, messages).await { Ok(_) => { // Successfully sent to server self.last_activity = SystemTime::now(); @@ -1359,16 +1352,14 @@ impl Server { } pub fn mirror_send(&mut self, bytes: &BytesMut) { - match self.mirror_manager.as_mut() { - Some(manager) => manager.send(bytes), - None => (), + if let Some(manager) = self.mirror_manager.as_mut() { + manager.send(bytes) } } pub fn mirror_disconnect(&mut self) { - match self.mirror_manager.as_mut() { - Some(manager) => manager.disconnect(), - None => (), + if let Some(manager) = self.mirror_manager.as_mut() { + manager.disconnect() } } @@ -1397,7 +1388,7 @@ impl Server { server.send(&simple_query(query)).await?; let mut message = server.recv(None).await?; - Ok(parse_query_message(&mut message).await?) + parse_query_message(&mut message).await } } diff --git a/src/sharding.rs b/src/sharding.rs index 18581dcf..a7a9df13 100644 --- a/src/sharding.rs +++ b/src/sharding.rs @@ -64,7 +64,7 @@ impl Sharder { fn sha1(&self, key: i64) -> usize { let mut hasher = Sha1::new(); - hasher.update(&key.to_string().as_bytes()); + hasher.update(key.to_string().as_bytes()); let result = hasher.finalize(); @@ -202,10 +202,10 @@ mod test { #[test] fn test_sha1_hash() { let sharder = Sharder::new(12, ShardingFunction::Sha1); - let ids = vec![ + let ids = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ]; - let shards = vec![ + let shards = [ 4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3, ]; diff --git a/src/stats/pool.rs b/src/stats/pool.rs index d3ac78e9..46c74632 100644 --- a/src/stats/pool.rs +++ b/src/stats/pool.rs @@ -86,11 +86,11 @@ impl PoolStats { } } - return map; + map } pub fn generate_header() -> Vec<(&'static str, DataType)> { - return vec![ + vec![ ("database", DataType::Text), ("user", DataType::Text), ("pool_mode", DataType::Text), @@ -105,11 +105,11 @@ impl PoolStats { ("sv_login", DataType::Numeric), ("maxwait", DataType::Numeric), ("maxwait_us", DataType::Numeric), - ]; + ] } pub fn generate_row(&self) -> Vec { - return vec![ + vec![ self.identifier.db.clone(), self.identifier.user.clone(), self.mode.to_string(), @@ -124,7 +124,7 @@ impl PoolStats { self.sv_login.to_string(), (self.maxwait / 1_000_000).to_string(), (self.maxwait % 1_000_000).to_string(), - ]; + ] } }