Skip to content

Commit

Permalink
fix(postgres): add missing type resolution for arrays by name
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Jul 8, 2024
1 parent efbf572 commit 16e3f10
Show file tree
Hide file tree
Showing 19 changed files with 333 additions and 84 deletions.
28 changes: 14 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ bit-vec = "0.6.3"
chrono = { version = "0.4.22", default-features = false }
ipnetwork = "0.20.0"
mac_address = "1.1.5"
rust_decimal = "1.26.1"
rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] }
time = { version = "0.3.36", features = ["formatting", "parsing", "macros"] }
uuid = "1.1.2"

Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ uuid = { workspace = true, optional = true }

async-io = { version = "1.9.0", optional = true }
paste = "1.0.6"
ahash = "0.8.7"
atoi = "2.0"

bytes = "1.1.0"
Expand Down Expand Up @@ -88,6 +87,7 @@ bstr = { version = "1.0", default-features = false, features = ["std"], optional
hashlink = "0.9.0"
indexmap = "2.0"
event-listener = "5.2.0"
hashbrown = "0.14.5"

[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
Expand Down
14 changes: 14 additions & 0 deletions sqlx-core/src/ext/ustr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ impl UStr {
pub fn new(s: &str) -> Self {
UStr::Shared(Arc::from(s.to_owned()))
}

/// Apply [str::strip_prefix], without copying if possible.
pub fn strip_prefix(this: &Self, prefix: &str) -> Option<Self> {
match this {
UStr::Static(s) => s.strip_prefix(prefix).map(Self::Static),
UStr::Shared(s) => s.strip_prefix(prefix).map(|s| Self::Shared(s.into())),
}
}
}

impl Deref for UStr {
Expand Down Expand Up @@ -60,6 +68,12 @@ impl From<&'static str> for UStr {
}
}

impl<'a> From<&'a UStr> for UStr {
fn from(value: &'a UStr) -> Self {
value.clone()
}
}

impl From<String> for UStr {
#[inline]
fn from(s: String) -> Self {
Expand Down
7 changes: 2 additions & 5 deletions sqlx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,15 @@ pub mod testing;

pub use error::{Error, Result};

/// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance.
pub use ahash::AHashMap as HashMap;
pub use either::Either;
pub use hashbrown::{hash_map, HashMap};
pub use indexmap::IndexMap;
pub use percent_encoding;
pub use smallvec::SmallVec;
pub use url::{self, Url};

pub use bytes;

//type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;

/// Helper module to get drivers compiling again that used to be in this crate,
/// to avoid having to replace tons of `use crate::<...>` imports.
///
Expand All @@ -119,6 +116,6 @@ pub mod driver_prelude {
};

pub use crate::error::{Error, Result};
pub use crate::HashMap;
pub use crate::{hash_map, HashMap};
pub use either::Either;
}
10 changes: 10 additions & 0 deletions sqlx-core/src/type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ pub trait TypeInfo: Debug + Display + Clone + PartialEq<Self> + Send + Sync {
/// should be a rough approximation of how they are written in SQL in the given database.
fn name(&self) -> &str;

/// Return `true` if `self` and `other` represent mutually compatible types.
///
/// Defaults to `self == other`.
fn type_compatible(&self, other: &Self) -> bool
where
Self: Sized,
{
self == other
}

#[doc(hidden)]
fn is_void(&self) -> bool {
false
Expand Down
4 changes: 3 additions & 1 deletion sqlx-core/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ pub trait Type<DB: Database> {
///
/// When binding arguments with `query!` or `query_as!`, this method is consulted to determine
/// if the Rust type is acceptable.
///
/// Defaults to checking [`TypeInfo::type_compatible()`].
fn compatible(ty: &DB::TypeInfo) -> bool {
*ty == Self::type_info()
Self::type_info().type_compatible(ty)
}
}

Expand Down
52 changes: 36 additions & 16 deletions sqlx-macros-core/src/derives/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,42 @@ use syn::{
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
let attrs = parse_container_attributes(&input.attrs)?;
match &input.data {
// Newtype structs:
// struct Foo(i32);
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
}) => {
if unnamed.len() == 1 {
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
} else {
Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
))
}
}
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
None => expand_derive_has_sql_type_strong_enum(input, variants),
},
// Record types
// struct Foo { foo: i32, bar: String }
Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) => expand_derive_has_sql_type_struct(input, named),
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
Data::Struct(DataStruct {
fields: Fields::Unnamed(..),
..
}) => Err(syn::Error::new_spanned(
input,
"structs with zero or more than one unnamed field are not supported",
)),
Data::Struct(DataStruct {
fields: Fields::Unit,
..
}) => Err(syn::Error::new_spanned(
input,
"unit structs are not supported",
)),

Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
// Enums that encode to/from integers (weak enums)
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
// Enums that decode to/from strings (strong enums)
None => expand_derive_has_sql_type_strong_enum(input, variants),
},
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
}
}

Expand Down Expand Up @@ -148,9 +155,10 @@ fn expand_derive_has_sql_type_weak_enum(

if cfg!(feature = "postgres") && !attrs.no_pg_array {
ts.extend(quote!(
#[automatically_derived]
impl ::sqlx::postgres::PgHasArrayType for #ident {
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
<#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info()
<#repr as ::sqlx::postgres::PgHasArrayType>::array_type_info()
}
}
));
Expand Down Expand Up @@ -197,9 +205,10 @@ fn expand_derive_has_sql_type_strong_enum(

if !attributes.no_pg_array {
tts.extend(quote!(
#[automatically_derived]
impl ::sqlx::postgres::PgHasArrayType for #ident {
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
<#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info()
::sqlx::postgres::PgTypeInfo::array_of(#ty_name)
}
}
));
Expand Down Expand Up @@ -244,6 +253,17 @@ fn expand_derive_has_sql_type_struct(
}
}
));

if !attributes.no_pg_array {
tts.extend(quote!(
#[automatically_derived]
impl ::sqlx::postgres::PgHasArrayType for #ident {
fn array_type_info() -> ::sqlx::postgres::PgTypeInfo {
::sqlx::postgres::PgTypeInfo::array_of(#ty_name)
}
}
));
}
}

Ok(tts)
Expand Down
3 changes: 3 additions & 0 deletions sqlx-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,8 @@ workspace = true
# We use JSON in the driver implementation itself so there's no reason not to enable it here.
features = ["json"]

[dev-dependencies]
sqlx.workspace = true

[target.'cfg(target_os = "windows")'.dependencies]
etcetera = "0.8.0"
30 changes: 26 additions & 4 deletions sqlx-postgres/src/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::fmt::{self, Write};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use crate::encode::{Encode, IsNull};
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::types::Type;
use crate::{PgConnection, PgTypeInfo, Postgres};

use crate::type_info::PgArrayOf;
pub(crate) use sqlx_core::arguments::Arguments;
use sqlx_core::error::BoxDynError;

Expand Down Expand Up @@ -41,7 +43,12 @@ pub struct PgArgumentBuffer {
// This is done for Records and Arrays as the OID is needed well before we are in an async
// function and can just ask postgres.
//
type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }>
}

enum HoleKind {
Type { name: UStr },
Array(Arc<PgArrayOf>),
}

struct Patch {
Expand Down Expand Up @@ -106,8 +113,11 @@ impl PgArguments {
(patch.callback)(buf, ty);
}

for (offset, name) in type_holes {
let oid = conn.fetch_type_id_by_name(name).await?;
for (offset, kind) in type_holes {
let oid = match kind {
HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
};
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
}

Expand Down Expand Up @@ -186,7 +196,19 @@ impl PgArgumentBuffer {
let offset = self.len();

self.extend_from_slice(&0_u32.to_be_bytes());
self.type_holes.push((offset, type_name.clone()));
self.type_holes.push((
offset,
HoleKind::Type {
name: type_name.clone(),
},
));
}

pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
let offset = self.len();

self.extend_from_slice(&0_u32.to_be_bytes());
self.type_holes.push((offset, HoleKind::Array(array)));
}

fn snapshot(&self) -> PgArgumentBufferSnapshot {
Expand Down
Loading

0 comments on commit 16e3f10

Please sign in to comment.