Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(postgres): Add support for domain types description #1112

Merged
merged 1 commit into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 125 additions & 16 deletions sqlx-core/src/postgres/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,86 @@ use crate::query_scalar::{query_scalar, query_scalar_with};
use crate::types::Json;
use crate::HashMap;
use futures_core::future::BoxFuture;
use std::convert::TryFrom;
use std::fmt::Write;
use std::sync::Arc;

/// Describes the type of the `pg_type.typtype` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
enum TypType {
Base,
Composite,
Domain,
Enum,
Pseudo,
Range,
}

impl TryFrom<u8> for TypType {
type Error = ();

fn try_from(t: u8) -> Result<Self, Self::Error> {
let t = match t {
b'b' => Self::Base,
b'c' => Self::Composite,
b'd' => Self::Domain,
b'e' => Self::Enum,
b'p' => Self::Pseudo,
b'r' => Self::Range,
_ => return Err(()),
};
Ok(t)
}
}

/// Describes the type of the `pg_type.typcategory` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
enum TypCategory {
Array,
Boolean,
Composite,
DateTime,
Enum,
Geometric,
Network,
Numeric,
Pseudo,
Range,
String,
Timespan,
User,
BitString,
Unknown,
}

impl TryFrom<u8> for TypCategory {
type Error = ();

fn try_from(c: u8) -> Result<Self, Self::Error> {
let c = match c {
b'A' => Self::Array,
b'B' => Self::Boolean,
b'C' => Self::Composite,
b'D' => Self::DateTime,
b'E' => Self::Enum,
b'G' => Self::Geometric,
b'I' => Self::Network,
b'N' => Self::Numeric,
b'P' => Self::Pseudo,
b'R' => Self::Range,
b'S' => Self::String,
b'T' => Self::Timespan,
b'U' => Self::User,
b'V' => Self::BitString,
b'X' => Self::Unknown,
_ => return Err(()),
};
Ok(c)
}
}

impl PgConnection {
pub(super) async fn handle_row_description(
&mut self,
Expand Down Expand Up @@ -106,31 +183,46 @@ impl PgConnection {

fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let (name, category, relation_id, element): (String, i8, u32, u32) = query_as(
"SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1",
let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, u32, u32, u32) = query_as(
"SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1",
)
.bind(oid)
.fetch_one(&mut *self)
.await?;

match category as u8 {
b'A' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
name: name.into(),
oid,
})))),
let typ_type = TypType::try_from(typ_type as u8);
let category = TypCategory::try_from(category as u8);

b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
})))),
match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,

(Ok(TypType::Base), Ok(TypCategory::Array)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
name: name.into(),
oid,
}))))
}

b'R' => self.fetch_range_by_oid(oid, name).await,
(Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
}))))
}

b'E' => self.fetch_enum_by_oid(oid, name).await,
(Ok(TypType::Range), Ok(TypCategory::Range)) => {
self.fetch_range_by_oid(oid, name).await
}

(Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
self.fetch_enum_by_oid(oid, name).await
}

b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await,
(Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
self.fetch_composite_by_oid(oid, relation_id, name).await
}

_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Simple,
Expand Down Expand Up @@ -198,6 +290,23 @@ ORDER BY attnum
})
}

fn fetch_domain_by_oid(
&mut self,
oid: u32,
base_type: u32,
name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;

Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
oid,
name: name.into(),
kind: PgTypeKind::Domain(base_type),
}))))
})
}

fn fetch_range_by_oid(
&mut self,
oid: u32,
Expand Down
112 changes: 112 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,115 @@ from (values (null)) vals(val)

Ok(())
}

#[sqlx_macros::test]
async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct MonthId(i16);

impl sqlx::Type<Postgres> for MonthId {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("month_id")
}

fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}

impl<'r> sqlx::Decode<'r, Postgres> for MonthId {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
Ok(Self(<i16 as sqlx::Decode<Postgres>>::decode(value)?))
}
}

impl<'q> sqlx::Encode<'q, Postgres> for MonthId {
fn encode_by_ref(
&self,
buf: &mut sqlx::postgres::PgArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode(buf)
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct WinterYearMonth {
year: i32,
month: MonthId,
}

impl sqlx::Type<Postgres> for WinterYearMonth {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("winter_year_month")
}

fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}

impl<'r> sqlx::Decode<'r, Postgres> for WinterYearMonth {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;

let year = decoder.try_decode::<i32>()?;
let month = decoder.try_decode::<MonthId>()?;

Ok(Self { year, month })
}
}

impl<'q> sqlx::Encode<'q, Postgres> for WinterYearMonth {
fn encode_by_ref(
&self,
buf: &mut sqlx::postgres::PgArgumentBuffer,
) -> sqlx::encode::IsNull {
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
encoder.encode(self.year);
encoder.encode(self.month);
encoder.finish();
sqlx::encode::IsNull::No
}
}

let mut conn = new::<Postgres>().await?;

{
let result = sqlx::query("DELETE FROM heating_bills;")
.execute(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}

{
let result = sqlx::query(
"INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);",
)
.bind(WinterYearMonth {
year: 2021,
month: MonthId(1),
})
.execute(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}

{
let result = sqlx::query("DELETE FROM heating_bills;")
.execute(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}

Ok(())
}
8 changes: 8 additions & 0 deletions tests/postgres/setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ CREATE TABLE products (
name TEXT,
price NUMERIC CHECK (price > 0)
);

CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12);
CREATE TYPE year_month AS (year INT4, month month_id);
CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3);
CREATE TABLE heating_bills (
month winter_year_month NOT NULL PRIMARY KEY,
cost INT4 NOT NULL
);