diff --git a/sqlx-core/src/common/statement_cache.rs b/sqlx-core/src/common/statement_cache.rs index e87785eb09..7ba19a610b 100644 --- a/sqlx-core/src/common/statement_cache.rs +++ b/sqlx-core/src/common/statement_cache.rs @@ -47,4 +47,9 @@ impl StatementCache { pub fn remove_lru(&mut self) -> Option { self.inner.remove_lru().map(|(_, v)| v) } + + /// Clear all cached statements from the cache. + pub fn clear(&mut self) { + self.inner.clear(); + } } diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs index 7b160ad16d..a0438e6f3c 100644 --- a/sqlx-core/src/postgres/connection/establish.rs +++ b/sqlx-core/src/postgres/connection/establish.rs @@ -1,5 +1,6 @@ use hashbrown::HashMap; +use crate::common::StatementCache; use crate::error::Error; use crate::io::Decode; use crate::postgres::connection::{sasl, stream::PgStream, tls}; @@ -138,7 +139,7 @@ impl PgConnection { transaction_status, pending_ready_for_query_count: 0, next_statement_id: 1, - cache_statement: HashMap::with_capacity(10), + cache_statement: StatementCache::new(options.statement_cache_size), cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), scratch_row_columns: Default::default(), diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index ed58ed56fa..b3fc706767 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -88,15 +88,16 @@ async fn recv_desc_rows(conn: &mut PgConnection) -> Result Result { - if let Some(statement) = self.cache_statement.get(query) { + if let Some(statement) = self.cache_statement.get_mut(query) { return Ok(*statement); } let statement = prepare(self, query, arguments).await?; - self.cache_statement.insert(query.to_owned(), statement); + self.cache_statement.insert(query, statement); Ok(statement) } diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index e3d1e22066..987769dfdc 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -5,6 +5,8 @@ use futures_core::future::BoxFuture; use futures_util::{FutureExt, TryFutureExt}; use hashbrown::HashMap; +use crate::caching_connection::CachingConnection; +use crate::common::StatementCache; use crate::connection::{Connect, Connection}; use crate::error::Error; use crate::executor::Executor; @@ -46,7 +48,7 @@ pub struct PgConnection { next_statement_id: u32, // cache statement by query string to the id and columns - cache_statement: HashMap, + cache_statement: StatementCache, // cache user-defined types by id <-> info cache_type_info: HashMap, @@ -96,6 +98,19 @@ impl Debug for PgConnection { } } +impl CachingConnection for PgConnection { + fn cached_statements_count(&self) -> usize { + self.cache_statement.len() + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + self.cache_statement.clear(); + Ok(()) + }) + } +} + impl Connection for PgConnection { type Database = Postgres; diff --git a/sqlx-core/src/postgres/options.rs b/sqlx-core/src/postgres/options.rs index b59a332f53..35971adea2 100644 --- a/sqlx-core/src/postgres/options.rs +++ b/sqlx-core/src/postgres/options.rs @@ -115,6 +115,7 @@ pub struct PgConnectOptions { pub(crate) database: Option, pub(crate) ssl_mode: PgSslMode, pub(crate) ssl_root_cert: Option, + pub(crate) statement_cache_size: usize, } impl Default for PgConnectOptions { @@ -162,6 +163,7 @@ impl PgConnectOptions { .ok() .and_then(|v| v.parse().ok()) .unwrap_or_default(), + statement_cache_size: 100, } } @@ -285,6 +287,17 @@ impl PgConnectOptions { self.ssl_root_cert = Some(cert.as_ref().to_path_buf()); self } + + /// Sets the size of the connection's statement cache in a number of stored + /// distinct statements. Caching is handled using LRU, meaning when the + /// amount of queries hits the defined limit, the oldest statement will get + /// dropped. + /// + /// The default cache size is 100 statements. + pub fn statement_cache_size(mut self, size: usize) -> Self { + self.statement_cache_size = size; + self + } } fn default_host(port: u16) -> String { diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 2aeea43dbc..b7a40cf45c 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1,6 +1,7 @@ use futures::TryStreamExt; use sqlx::postgres::PgRow; use sqlx::postgres::{PgDatabaseError, PgErrorPosition, PgSeverity}; +use sqlx::CachingConnection; use sqlx::{postgres::Postgres, Connection, Executor, PgPool, Row}; use sqlx_test::new; use std::time::Duration; @@ -487,3 +488,25 @@ SELECT id, text FROM _sqlx_test_postgres_5112; Ok(()) } + +#[sqlx_macros::test] +async fn it_caches_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + for i in 0..2 { + let row = sqlx::query("SELECT $1 AS val") + .bind(i) + .fetch_one(&mut conn) + .await?; + + let val: u32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(1, conn.cached_statements_count()); + conn.clear_cached_statements().await?; + assert_eq!(0, conn.cached_statements_count()); + + Ok(()) +}