From 1eb3d41dd6dcd3aa44b6be68cd7b730e7a3928b2 Mon Sep 17 00:00:00 2001 From: Charles Samborski Date: Tue, 16 Mar 2021 18:39:52 +0100 Subject: [PATCH] fix(postgres): Add support for domain types description Fix commit updates the `postgres::connection::describe` module to add full support for domain types. Domain types were previously confused with their category which caused invalid oid resolution. Fixes launchbadge/sqlx#110 --- sqlx-core/src/postgres/connection/describe.rs | 116 ++++++++++++++++-- tests/postgres/postgres.rs | 97 +++++++++++++++ 2 files changed, 205 insertions(+), 8 deletions(-) diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index f9e2ebf300..d203d4e72e 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -11,6 +11,83 @@ use crate::HashMap; use futures_core::future::BoxFuture; use std::fmt::Write; use std::sync::Arc; +use std::convert::TryFrom; + +/// Describes the type of the `pg_type.typtype` column +/// +/// See +enum TypType { + Base, + Composite, + Domain, + Enum, + Pseudo, + Range, +} + +impl TryFrom for TypType { + type Error = (); + + fn try_from(t: u8) -> Result { + let t = match t { + b'b' => Self::Base, + b'c' => Self::Composite, + b'd' => Self::Domain, + b'e' => Self::Enum, + b'p' => Self::Pseudo, + b'r' => Self::Range, + _ => return Err(()), + }; + Ok(t) + } +} + +/// Describes the type of the `pg_type.typcategory` column +/// +/// See +enum TypCategory { + Array, + Boolean, + Composite, + DateTime, + Enum, + Geometric, + Network, + Numeric, + Pseudo, + Range, + String, + Timespan, + User, + BitString, + Unknown, +} + +impl TryFrom for TypCategory { + type Error = (); + + fn try_from(c: u8) -> Result { + let c = match c { + b'A' => Self::Array, + b'B' => Self::Boolean, + b'C' => Self::Composite, + b'D' => Self::DateTime, + b'E' => Self::Enum, + b'G' => Self::Geometric, + b'I' => Self::Network, + b'N' => Self::Numeric, + b'P' => Self::Pseudo, + b'R' => Self::Range, + b'S' => Self::String, + b'T' => Self::Timespan, + b'U' => Self::User, + b'V' => Self::BitString, + b'X' => Self::Unknown, + _ => return Err(()), + }; + Ok(c) + } +} impl PgConnection { pub(super) async fn handle_row_description( @@ -106,31 +183,37 @@ impl PgConnection { fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result> { Box::pin(async move { - let (name, category, relation_id, element): (String, i8, u32, u32) = query_as( - "SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1", + + let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, u32, u32, u32) = query_as( + "SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1", ) .bind(oid) .fetch_one(&mut *self) .await?; - match category as u8 { - b'A' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + let typ_type = TypType::try_from(typ_type as u8); + let category = TypCategory::try_from(category as u8); + + match (typ_type, category) { + (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, + + (Ok(TypType::Base), Ok(TypCategory::Array)) => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?), name: name.into(), oid, })))), - b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Pseudo, name: name.into(), oid, })))), - b'R' => self.fetch_range_by_oid(oid, name).await, + (Ok(TypType::Range), Ok(TypCategory::Range)) => self.fetch_range_by_oid(oid, name).await, - b'E' => self.fetch_enum_by_oid(oid, name).await, + (Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await, - b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await, + (Ok(TypType::Composite), Ok(TypCategory::Composite)) => self.fetch_composite_by_oid(oid, relation_id, name).await, _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Simple, @@ -198,6 +281,23 @@ ORDER BY attnum }) } + fn fetch_domain_by_oid( + &mut self, + oid: u32, + base_type: u32, + name: String, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?; + + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Domain(base_type), + })))) + }) + } + fn fetch_range_by_oid( &mut self, oid: u32, diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index dee9062d8e..fcd618e468 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -887,3 +887,100 @@ from (values (null)) vals(val) Ok(()) } + +#[sqlx_macros::test] +async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct MonthId(i16); + + impl sqlx::Type for MonthId { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("month_id") + } + + fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == Self::type_info() + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for MonthId { + fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result> { + Ok(Self(>::decode(value)?)) + } + } + + impl<'q> sqlx::Encode<'q, Postgres> for MonthId { + fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull { + self.0.encode(buf) + } + } + + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct WinterYearMonth { + year: i32, + month: MonthId + } + + impl sqlx::Type for WinterYearMonth { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("winter_year_month") + } + + fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == Self::type_info() + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for WinterYearMonth { + fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result> { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + + let year = decoder.try_decode::()?; + let month = decoder.try_decode::()?; + + Ok(Self { year, month }) + } + } + + impl<'q> sqlx::Encode<'q, Postgres> for WinterYearMonth { + fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull { + let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf); + encoder.encode(self.year); + encoder.encode(self.month); + encoder.finish(); + sqlx::encode::IsNull::No + } + } + + let mut conn = new::().await?; + + { + let result = sqlx::query("DELETE FROM heating_bills;") + .execute(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.rows_affected(), 1); + } + + { + let result = sqlx::query("INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);") + .bind(WinterYearMonth { year: 2021, month: MonthId(1) }) + .execute(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.rows_affected(), 1); + } + + { + let result = sqlx::query("DELETE FROM heating_bills;") + .execute(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.rows_affected(), 1); + } + + Ok(()) +}