diff --git a/sqlx-core/src/postgres/types/lquery.rs b/sqlx-core/src/postgres/types/lquery.rs new file mode 100644 index 0000000000..ddb05bd43e --- /dev/null +++ b/sqlx-core/src/postgres/types/lquery.rs @@ -0,0 +1,313 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use crate::types::Type; +use bitflags::bitflags; +use std::fmt::{self, Display, Formatter}; +use std::io::Write; +use std::ops::Deref; +use std::str::FromStr; + +use crate::postgres::types::ltree::{PgLTreeLabel, PgLTreeParseError}; + +/// Represents lquery specific errors +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum PgLQueryParseError { + #[error("lquery cannot be empty")] + EmptyString, + #[error("unexpected character in lquery")] + UnexpectedCharacter, + #[error("error parsing integer: {0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("error parsing integer: {0}")] + LTreeParrseError(#[from] PgLTreeParseError), + /// LQuery version not supported + #[error("lquery version not supported")] + InvalidLqueryVersion, +} + +/// Container for a Label Tree Query (`lquery`) in Postgres. +/// +/// See https://www.postgresql.org/docs/current/ltree.html +/// +/// ### Note: Requires Postgres 13+ +/// +/// This integration requires that the `lquery` type support the binary format in the Postgres +/// wire protocol, which only became available in Postgres 13. +/// ([Postgres 13.0 Release Notes, Additional Modules][https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14]) +/// +/// Ideally, SQLx's Postgres driver should support falling back to text format for types +/// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs +/// to be done. +/// +/// ### Note: Extension Required +/// The `ltree` extension is not enabled by default in Postgres. You will need to do so explicitly: +/// +/// ```ignore +/// CREATE EXTENSION IF NOT EXISTS "ltree"; +/// ``` +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PgLQuery { + levels: Vec, +} + +// TODO: maybe a QueryBuilder pattern would be nice here +impl PgLQuery { + /// creates default/empty lquery + pub fn new() -> Self { + Self::default() + } + + pub fn from(levels: Vec) -> Self { + Self { levels } + } + + /// push a query level + pub fn push(&mut self, level: PgLQueryLevel) { + self.levels.push(level); + } + + /// pop a query level + pub fn pop(&mut self) -> Option { + self.levels.pop() + } + + /// creates lquery from an iterator with checking labels + pub fn from_iter(levels: I) -> Result + where + S: Into, + I: IntoIterator, + { + let mut lquery = Self::default(); + for level in levels { + lquery.push(PgLQueryLevel::from_str(&level.into())?); + } + Ok(lquery) + } +} + +impl IntoIterator for PgLQuery { + type Item = PgLQueryLevel; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.levels.into_iter() + } +} + +impl FromStr for PgLQuery { + type Err = PgLQueryParseError; + + fn from_str(s: &str) -> Result { + Ok(Self { + levels: s + .split('.') + .map(|s| PgLQueryLevel::from_str(s)) + .collect::>()?, + }) + } +} + +impl Display for PgLQuery { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut iter = self.levels.iter(); + if let Some(label) = iter.next() { + write!(f, "{}", label)?; + for label in iter { + write!(f, ".{}", label)?; + } + } + Ok(()) + } +} + +impl Deref for PgLQuery { + type Target = [PgLQueryLevel]; + + fn deref(&self) -> &Self::Target { + &self.levels + } +} + +impl Type for PgLQuery { + fn type_info() -> PgTypeInfo { + // Since `ltree` is enabled by an extension, it does not have a stable OID. + PgTypeInfo::with_name("lquery") + } +} + +impl Encode<'_, Postgres> for PgLQuery { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + buf.extend(1i8.to_le_bytes()); + write!(buf, "{}", self) + .expect("Display implementation panicked while writing to PgArgumentBuffer"); + + IsNull::No + } +} + +impl<'r> Decode<'r, Postgres> for PgLQuery { + fn decode(value: PgValueRef<'r>) -> Result { + match value.format() { + PgValueFormat::Binary => { + let bytes = value.as_bytes()?; + let version = i8::from_le_bytes([bytes[0]; 1]); + if version != 1 { + return Err(Box::new(PgLQueryParseError::InvalidLqueryVersion)); + } + Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?) + } + PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?), + } + } +} + +bitflags! { + /// Modifiers that can be set to non-star labels + pub struct PgLQueryVariantFlag: u16 { + /// * - Match any label with this prefix, for example foo* matches foobar + const ANY_END = 0x01; + /// @ - Match case-insensitively, for example a@ matches A + const IN_CASE = 0x02; + /// % - Match initial underscore-separated words + const SUBLEXEME = 0x04; + } +} + +impl Display for PgLQueryVariantFlag { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.contains(PgLQueryVariantFlag::ANY_END) { + write!(f, "*")?; + } + if self.contains(PgLQueryVariantFlag::IN_CASE) { + write!(f, "@")?; + } + if self.contains(PgLQueryVariantFlag::SUBLEXEME) { + write!(f, "%")?; + } + + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PgLQueryVariant { + label: PgLTreeLabel, + modifiers: PgLQueryVariantFlag, +} + +impl Display for PgLQueryVariant { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}{}", self.label, self.modifiers) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum PgLQueryLevel { + /// match any label (*) with optional at least / at most numbers + Star(Option, Option), + /// match any of specified labels with optional flags + NonStar(Vec), + /// match none of specified labels with optional flags + NotNonStar(Vec), +} + +impl FromStr for PgLQueryLevel { + type Err = PgLQueryParseError; + + fn from_str(s: &str) -> Result { + let bytes = s.as_bytes(); + if bytes.is_empty() { + Err(PgLQueryParseError::EmptyString) + } else { + match bytes[0] { + b'*' => { + if bytes.len() > 1 { + let parts = s[2..s.len() - 1].split(',').collect::>(); + match parts.len() { + 1 => { + let number = parts[0].parse()?; + Ok(PgLQueryLevel::Star(Some(number), Some(number))) + } + 2 => Ok(PgLQueryLevel::Star( + Some(parts[0].parse()?), + Some(parts[1].parse()?), + )), + _ => Err(PgLQueryParseError::UnexpectedCharacter), + } + } else { + Ok(PgLQueryLevel::Star(None, None)) + } + } + b'!' => Ok(PgLQueryLevel::NotNonStar( + s[1..] + .split('|') + .map(|s| PgLQueryVariant::from_str(s)) + .collect::, PgLQueryParseError>>()?, + )), + _ => Ok(PgLQueryLevel::NonStar( + s.split('|') + .map(|s| PgLQueryVariant::from_str(s)) + .collect::, PgLQueryParseError>>()?, + )), + } + } + } +} + +impl FromStr for PgLQueryVariant { + type Err = PgLQueryParseError; + + fn from_str(s: &str) -> Result { + let mut label_length = s.len(); + let mut rev_iter = s.bytes().rev(); + let mut modifiers = PgLQueryVariantFlag { bits: 0 }; + + while let Some(b) = rev_iter.next() { + match b { + b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE), + b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END), + b'%' => modifiers.insert(PgLQueryVariantFlag::SUBLEXEME), + _ => break, + } + label_length -= 1; + } + + Ok(PgLQueryVariant { + label: PgLTreeLabel::new(&s[0..label_length])?, + modifiers, + }) + } +} + +fn write_variants(f: &mut Formatter<'_>, variants: &[PgLQueryVariant], not: bool) -> fmt::Result { + let mut iter = variants.iter(); + if let Some(variant) = iter.next() { + write!(f, "{}{}", if not { "!" } else { "" }, variant)?; + for variant in iter { + write!(f, ".{}", variant)?; + } + } + Ok(()) +} + +impl Display for PgLQueryLevel { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + PgLQueryLevel::Star(Some(at_least), Some(at_most)) => { + if at_least == at_most { + write!(f, "*{{{}}}", at_least) + } else { + write!(f, "*{{{},{}}}", at_least, at_most) + } + } + PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{},}}", at_least), + PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{}}}", at_most), + PgLQueryLevel::Star(_, _) => write!(f, "*"), + PgLQueryLevel::NonStar(variants) => write_variants(f, &variants, false), + PgLQueryLevel::NotNonStar(variants) => write_variants(f, &variants, true), + } + } +} diff --git a/sqlx-core/src/postgres/types/ltree.rs b/sqlx-core/src/postgres/types/ltree.rs index 23173f528e..0e9d7768d0 100644 --- a/sqlx-core/src/postgres/types/ltree.rs +++ b/sqlx-core/src/postgres/types/ltree.rs @@ -23,6 +23,45 @@ pub enum PgLTreeParseError { InvalidLtreeVersion, } +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PgLTreeLabel(String); + +impl PgLTreeLabel { + pub fn new(label: &str) -> Result { + if label.len() <= 256 + && label + .bytes() + .all(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || c == b'_') + { + Ok(Self(label.to_owned())) + } else { + Err(PgLTreeParseError::InvalidLtreeLabel) + } + } +} + +impl Deref for PgLTreeLabel { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_str() + } +} + +impl FromStr for PgLTreeLabel { + type Err = PgLTreeParseError; + + fn from_str(s: &str) -> Result { + PgLTreeLabel::new(s) + } +} + +impl Display for PgLTreeLabel { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + /// Container for a Label Tree (`ltree`) in Postgres. /// /// See https://www.postgresql.org/docs/current/ltree.html @@ -45,7 +84,7 @@ pub enum PgLTreeParseError { /// ``` #[derive(Clone, Debug, Default, PartialEq)] pub struct PgLTree { - labels: Vec, + labels: Vec, } impl PgLTree { @@ -54,8 +93,8 @@ impl PgLTree { Self::default() } - /// creates ltree from a [Vec] without checking labels - pub fn new_unchecked(labels: Vec) -> Self { + /// creates ltree from a [Vec] + pub fn from(labels: Vec) -> Self { Self { labels } } @@ -67,33 +106,25 @@ impl PgLTree { { let mut ltree = Self::default(); for label in labels { - ltree.push(label.into())?; + ltree.push(&label.into())?; } Ok(ltree) } /// push a label to ltree - pub fn push(&mut self, label: String) -> Result<(), PgLTreeParseError> { - if label.len() <= 256 - && label - .bytes() - .all(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || c == b'_') - { - self.labels.push(label); - Ok(()) - } else { - Err(PgLTreeParseError::InvalidLtreeLabel) - } + pub fn push(&mut self, label: &str) -> Result<(), PgLTreeParseError> { + self.labels.push(PgLTreeLabel::new(label)?); + Ok(()) } /// pop a label from ltree - pub fn pop(&mut self) -> Option { + pub fn pop(&mut self) -> Option { self.labels.pop() } } impl IntoIterator for PgLTree { - type Item = String; + type Item = PgLTreeLabel; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { @@ -106,7 +137,10 @@ impl FromStr for PgLTree { fn from_str(s: &str) -> Result { Ok(Self { - labels: s.split('.').map(|s| s.to_owned()).collect(), + labels: s + .split('.') + .map(|s| PgLTreeLabel::new(s)) + .collect::, Self::Err>>()?, }) } } @@ -125,7 +159,7 @@ impl Display for PgLTree { } impl Deref for PgLTree { - type Target = [String]; + type Target = [PgLTreeLabel]; fn deref(&self) -> &Self::Target { &self.labels diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 26524f85d9..19a8165f17 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -168,6 +168,7 @@ mod bytes; mod float; mod int; mod interval; +mod lquery; mod ltree; mod money; mod range; @@ -211,7 +212,12 @@ mod bit_vec; pub use array::PgHasArrayType; pub use interval::PgInterval; +pub use lquery::PgLQuery; +pub use lquery::PgLQueryLevel; +pub use lquery::PgLQueryVariant; +pub use lquery::PgLQueryVariantFlag; pub use ltree::PgLTree; +pub use ltree::PgLTreeLabel; pub use ltree::PgLTreeParseError; pub use money::PgMoney; pub use range::PgRange; diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index 3d51641d34..0f5d59010d 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -20,6 +20,8 @@ impl_database_ext! { sqlx::postgres::types::PgLTree, + sqlx::postgres::types::PgLQuery, + #[cfg(feature = "uuid")] sqlx::types::Uuid,