Skip to content

Commit

Permalink
fix(postgres): Add support for domain types description
Browse files Browse the repository at this point in the history
Fix commit updates the `postgres::connection::describe` module to add full support for domain types. Domain types were previously confused with their category which caused invalid oid resolution.

Fixes #110
  • Loading branch information
demurgos committed Mar 16, 2021
1 parent edcc91c commit 1eb3d41
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 8 deletions.
116 changes: 108 additions & 8 deletions sqlx-core/src/postgres/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,83 @@ use crate::HashMap;
use futures_core::future::BoxFuture;
use std::fmt::Write;
use std::sync::Arc;
use std::convert::TryFrom;

/// Describes the type of the `pg_type.typtype` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
enum TypType {
Base,
Composite,
Domain,
Enum,
Pseudo,
Range,
}

impl TryFrom<u8> for TypType {
type Error = ();

fn try_from(t: u8) -> Result<Self, Self::Error> {
let t = match t {
b'b' => Self::Base,
b'c' => Self::Composite,
b'd' => Self::Domain,
b'e' => Self::Enum,
b'p' => Self::Pseudo,
b'r' => Self::Range,
_ => return Err(()),
};
Ok(t)
}
}

/// Describes the type of the `pg_type.typcategory` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
enum TypCategory {
Array,
Boolean,
Composite,
DateTime,
Enum,
Geometric,
Network,
Numeric,
Pseudo,
Range,
String,
Timespan,
User,
BitString,
Unknown,
}

impl TryFrom<u8> for TypCategory {
type Error = ();

fn try_from(c: u8) -> Result<Self, Self::Error> {
let c = match c {
b'A' => Self::Array,
b'B' => Self::Boolean,
b'C' => Self::Composite,
b'D' => Self::DateTime,
b'E' => Self::Enum,
b'G' => Self::Geometric,
b'I' => Self::Network,
b'N' => Self::Numeric,
b'P' => Self::Pseudo,
b'R' => Self::Range,
b'S' => Self::String,
b'T' => Self::Timespan,
b'U' => Self::User,
b'V' => Self::BitString,
b'X' => Self::Unknown,
_ => return Err(()),
};
Ok(c)
}
}

impl PgConnection {
pub(super) async fn handle_row_description(
Expand Down Expand Up @@ -106,31 +183,37 @@ impl PgConnection {

fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let (name, category, relation_id, element): (String, i8, u32, u32) = query_as(
"SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1",

let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, u32, u32, u32) = query_as(
"SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1",
)
.bind(oid)
.fetch_one(&mut *self)
.await?;

match category as u8 {
b'A' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
let typ_type = TypType::try_from(typ_type as u8);
let category = TypCategory::try_from(category as u8);

match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,

(Ok(TypType::Base), Ok(TypCategory::Array)) => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
name: name.into(),
oid,
})))),

b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
(Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
})))),

b'R' => self.fetch_range_by_oid(oid, name).await,
(Ok(TypType::Range), Ok(TypCategory::Range)) => self.fetch_range_by_oid(oid, name).await,

b'E' => self.fetch_enum_by_oid(oid, name).await,
(Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await,

b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await,
(Ok(TypType::Composite), Ok(TypCategory::Composite)) => self.fetch_composite_by_oid(oid, relation_id, name).await,

_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Simple,
Expand Down Expand Up @@ -198,6 +281,23 @@ ORDER BY attnum
})
}

fn fetch_domain_by_oid(
&mut self,
oid: u32,
base_type: u32,
name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;

Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
oid,
name: name.into(),
kind: PgTypeKind::Domain(base_type),
}))))
})
}

fn fetch_range_by_oid(
&mut self,
oid: u32,
Expand Down
97 changes: 97 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,100 @@ from (values (null)) vals(val)

Ok(())
}

#[sqlx_macros::test]
async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct MonthId(i16);

impl sqlx::Type<Postgres> for MonthId {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("month_id")
}

fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}

impl<'r> sqlx::Decode<'r, Postgres> for MonthId {
fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
Ok(Self(<i16 as sqlx::Decode<Postgres>>::decode(value)?))
}
}

impl<'q> sqlx::Encode<'q, Postgres> for MonthId {
fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull {
self.0.encode(buf)
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct WinterYearMonth {
year: i32,
month: MonthId
}

impl sqlx::Type<Postgres> for WinterYearMonth {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("winter_year_month")
}

fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}

impl<'r> sqlx::Decode<'r, Postgres> for WinterYearMonth {
fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;

let year = decoder.try_decode::<i32>()?;
let month = decoder.try_decode::<MonthId>()?;

Ok(Self { year, month })
}
}

impl<'q> sqlx::Encode<'q, Postgres> for WinterYearMonth {
fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull {
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
encoder.encode(self.year);
encoder.encode(self.month);
encoder.finish();
sqlx::encode::IsNull::No
}
}

let mut conn = new::<Postgres>().await?;

{
let result = sqlx::query("DELETE FROM heating_bills;")
.execute(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}

{
let result = sqlx::query("INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);")
.bind(WinterYearMonth { year: 2021, month: MonthId(1) })
.execute(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}

{
let result = sqlx::query("DELETE FROM heating_bills;")
.execute(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}

Ok(())
}

0 comments on commit 1eb3d41

Please sign in to comment.