diff --git a/sqlx-store/src/postgres_store.rs b/sqlx-store/src/postgres_store.rs index 13f55a2..91cea9a 100644 --- a/sqlx-store/src/postgres_store.rs +++ b/sqlx-store/src/postgres_store.rs @@ -35,6 +35,36 @@ impl PostgresStore { } } + /// Set the session table schema name with the provided name. + pub fn with_schema_name(mut self, schema_name: impl AsRef) -> Result { + let schema_name = schema_name.as_ref(); + if !is_valid_identifier(schema_name) { + return Err(format!( + "Invalid schema name '{}'. Schema names must start with a letter or underscore (including letters with diacritical marks and non-Latin letters).\ + Subsequent characters can be letters, underscores, digits (0-9), or dollar signs ($).", + schema_name + )); + } + + self.schema_name = schema_name.to_owned(); + Ok(self) + } + + /// Set the session table name with the provided name. + pub fn with_table_name(mut self, table_name: impl AsRef) -> Result { + let table_name = table_name.as_ref(); + if !is_valid_identifier(table_name) { + return Err(format!( + "Invalid table name '{}'. Table names must start with a letter or underscore (including letters with diacritical marks and non-Latin letters).\ + Subsequent characters can be letters, underscores, digits (0-9), or dollar signs ($).", + table_name + )); + } + + self.table_name = table_name.to_owned(); + Ok(self) + } + /// Migrate the session schema. /// /// # Examples @@ -169,3 +199,18 @@ impl SessionStore for PostgresStore { Ok(()) } } + +/// A valid PostreSQL identifier must start with a letter or underscore (including letters with diacritical marks and non-Latin letters). +/// Subsequent characters in an identifier or key word can be letters, underscores, digits (0-9), or dollar signs ($). +/// See https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS for details. +fn is_valid_identifier(name: &str) -> bool { + !name.is_empty() + && name + .chars() + .next() + .map(|c| c.is_alphabetic() || c == '_') + .unwrap_or_default() + && name + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '$') +}