Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mysql): Close prepared statement if persistence is disabled #2905

Merged
merged 2 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions sqlx-mysql/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,10 @@ use futures_util::{pin_mut, TryStreamExt};
use std::{borrow::Cow, sync::Arc};

impl MySqlConnection {
async fn get_or_prepare<'c>(
async fn prepare_statement<'c>(
&mut self,
sql: &str,
persistent: bool,
) -> Result<(u32, MySqlStatementMetadata), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
// <MySqlStatementMetadata> is internally reference-counted
return Ok((*statement).clone());
}

// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK

Expand Down Expand Up @@ -72,11 +66,23 @@ impl MySqlConnection {
column_names: Arc::new(column_names),
};

if persistent && self.cache_statement.is_enabled() {
// in case of the cache being full, close the least recently used statement
if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) {
self.stream.send_packet(StmtClose { statement: id }).await?;
}
Ok((id, metadata))
}

async fn get_or_prepare_statement<'c>(
&mut self,
sql: &str,
) -> Result<(u32, MySqlStatementMetadata), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
// <MySqlStatementMetadata> is internally reference-counted
return Ok((*statement).clone());
}

let (id, metadata) = self.prepare_statement(sql).await?;

// in case of the cache being full, close the least recently used statement
if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) {
self.stream.send_packet(StmtClose { statement: id }).await?;
}

Ok((id, metadata))
Expand All @@ -102,21 +108,37 @@ impl MySqlConnection {
let mut columns = Arc::new(Vec::new());

let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments {
let (id, metadata) = self.get_or_prepare(
sql,
persistent,
)
.await?;

// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement: id,
arguments: &arguments,
})
.await?;

(metadata.column_names, MySqlValueFormat::Binary, false)
if persistent && self.cache_statement.is_enabled() {
let (id, metadata) = self
.get_or_prepare_statement(sql)
.await?;

// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement: id,
arguments: &arguments,
})
.await?;

(metadata.column_names, MySqlValueFormat::Binary, false)
} else {
let (id, metadata) = self
.prepare_statement(sql)
.await?;

// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement: id,
arguments: &arguments,
})
.await?;

self.stream.send_packet(StmtClose { statement: id }).await?;

(metadata.column_names, MySqlValueFormat::Binary, false)
}
} else {
// https://dev.mysql.com/doc/internals/en/com-query.html
self.stream.send_packet(Query(sql)).await?;
Expand Down Expand Up @@ -269,7 +291,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
Box::pin(async move {
self.stream.wait_until_ready().await?;

let (_, metadata) = self.get_or_prepare(sql, true).await?;
let metadata = if self.cache_statement.is_enabled() {
self.get_or_prepare_statement(sql).await?.1
} else {
let (id, metadata) = self.prepare_statement(sql).await?;

self.stream.send_packet(StmtClose { statement: id }).await?;

metadata
};

Ok(MySqlStatement {
sql: Cow::Borrowed(sql),
Expand All @@ -287,7 +317,9 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
Box::pin(async move {
self.stream.wait_until_ready().await?;

let (_, metadata) = self.get_or_prepare(sql, false).await?;
let (id, metadata) = self.prepare_statement(sql).await?;

self.stream.send_packet(StmtClose { statement: id }).await?;

let columns = (&*metadata.columns).clone();

Expand Down
66 changes: 66 additions & 0 deletions tests/mysql/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,57 @@ async fn it_caches_statements() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn it_closes_statements_with_persistent_disabled() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

for i in 0..2 {
let row = sqlx::query("SELECT ? AS val")
.bind(i)
.persistent(false)
.fetch_one(&mut conn)
.await?;

let val: i32 = row.get("val");

assert_eq!(i, val);
}

let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

assert_eq!(old_statement_count, new_statement_count);

Ok(())
}

#[sqlx_macros::test]
async fn it_closes_statements_with_cache_disabled() -> anyhow::Result<()> {
setup_if_needed();

let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?;
url.query_pairs_mut()
.append_pair("statement-cache-capacity", "0");

let mut conn = MySqlConnection::connect(url.as_ref()).await?;

let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

for index in 1..=10_i32 {
let _ = sqlx::query("SELECT ?")
.bind(index)
.execute(&mut conn)
.await?;
}

let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

assert_eq!(old_statement_count, new_statement_count);

Ok(())
}

#[sqlx_macros::test]
async fn it_can_bind_null_and_non_null_issue_540() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
Expand Down Expand Up @@ -510,3 +561,18 @@ async fn test_shrink_buffers() -> anyhow::Result<()> {

Ok(())
}

async fn select_statement_count(conn: &mut MySqlConnection) -> Result<i64, sqlx::Error> {
// Fails if performance schema does not exist
sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM performance_schema.threads AS t
INNER JOIN performance_schema.prepared_statements_instances AS psi
ON psi.OWNER_THREAD_ID = t.THREAD_ID
WHERE t.processlist_id = CONNECTION_ID()
"#,
)
.fetch_one(conn)
.await
}