diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 4a1f193948..ffcc5e528d 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -244,3 +244,10 @@ macro_rules! impl_fmt_error { } }; } + +#[allow(unused_macros)] +macro_rules! decode_err ( + ($($args:tt)*) => { + $crate::decode::DecodeError::Message(Box::new(format!($($args)*))) + } +); diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs new file mode 100644 index 0000000000..c23a11993e --- /dev/null +++ b/sqlx-core/src/postgres/types/array.rs @@ -0,0 +1,213 @@ +/// Encoding and decoding of Postgres arrays. Documentation of the byte format can be found [here](https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/include/utils/array.h;h=7f7e744cb12bc872f628f90dad99dfdf074eb314;hb=master#l6) +use crate::decode::Decode; +use crate::decode::DecodeError; +use crate::encode::Encode; +use crate::io::{Buf, BufMut}; +use crate::postgres::database::Postgres; +use crate::types::HasSqlType; +use std::marker::PhantomData; + +impl Encode for [T] +where + T: Encode, + Postgres: HasSqlType, +{ + fn encode(&self, buf: &mut Vec) { + let mut encoder = ArrayEncoder::new(buf); + for item in self { + encoder.push(item); + } + } +} +impl Encode for Vec +where + [T]: Encode, + Postgres: HasSqlType, +{ + fn encode(&self, buf: &mut Vec) { + self.as_slice().encode(buf) + } +} + +impl Decode for Vec +where + T: Decode, + Postgres: HasSqlType, +{ + fn decode(buf: &[u8]) -> Result { + let decoder = ArrayDecoder::::new(buf)?; + decoder.collect() + } +} + +type Order = byteorder::BigEndian; + +struct ArrayDecoder<'a, T> +where + T: Decode, + Postgres: HasSqlType, +{ + left: usize, + did_error: bool, + + buf: &'a [u8], + + phantom: PhantomData, +} + +impl ArrayDecoder<'_, T> +where + T: Decode, + Postgres: HasSqlType, +{ + fn new(mut buf: &[u8]) -> Result, DecodeError> { + let ndim = buf.get_i32::()?; + let dataoffset = buf.get_i32::()?; + let elemtype = buf.get_i32::()?; + + if ndim == 0 { + return Ok(ArrayDecoder { + left: 0, + did_error: false, + buf, + phantom: PhantomData, + }); + } + + if ndim != 1 { + return Err(decode_err!( + "only arrays of dimension 1 is supported, found array of dimension {}", + ndim + )); + } + + let dimensions = buf.get_i32::()?; + let lower_bnds = buf.get_i32::()?; + + if dataoffset != 0 { + // arrays with [null bitmap] is not supported + return Err(DecodeError::UnexpectedNull); + } + if elemtype != >::type_info().id.0 as i32 { + return Err(decode_err!("mismatched array element type")); + } + if lower_bnds != 1 { + return Err(decode_err!( + "expected lower_bnds of array to be 1, but found {}", + lower_bnds + )); + } + + Ok(ArrayDecoder { + left: dimensions as usize, + did_error: false, + buf, + + phantom: PhantomData, + }) + } + + /// Decodes the next element without worring how many are left, or if it previously errored + fn decode_next_element(&mut self) -> Result { + let len = self.buf.get_i32::()?; + let bytes = self.buf.get_bytes(len as usize)?; + Decode::decode(bytes) + } +} + +impl Iterator for ArrayDecoder<'_, T> +where + T: Decode, + Postgres: HasSqlType, +{ + type Item = Result; + + fn next(&mut self) -> Option> { + if self.did_error || self.left == 0 { + return None; + } + + self.left -= 1; + + let decoded = self.decode_next_element(); + self.did_error = decoded.is_err(); + Some(decoded) + } +} + +struct ArrayEncoder<'a, T> +where + T: Encode, + Postgres: HasSqlType, +{ + count: usize, + len_start_index: usize, + buf: &'a mut Vec, + + phantom: PhantomData, +} + +impl ArrayEncoder<'_, T> +where + T: Encode, + Postgres: HasSqlType, +{ + fn new(buf: &mut Vec) -> ArrayEncoder { + let ty = >::type_info(); + + // ndim + buf.put_i32::(1); + // dataoffset + buf.put_i32::(0); + // elemtype + buf.put_i32::(ty.id.0 as i32); + let len_start_index = buf.len(); + // dimensions + buf.put_i32::(0); + // lower_bnds + buf.put_i32::(1); + + ArrayEncoder { + count: 0, + len_start_index, + buf, + + phantom: PhantomData, + } + } + fn push(&mut self, item: &T) { + // Allocate space for the length of the encoded elemement up front + let el_len_index = self.buf.len(); + self.buf.put_i32::(0); + + // Allocate and encode the element it self + let el_start = self.buf.len(); + Encode::encode(item, self.buf); + let el_end = self.buf.len(); + + // Now we know the actual length of the encoded element + let el_len = el_end - el_start; + + // And we can now go back and update the length + self.buf[el_len_index..el_start].copy_from_slice(&(el_len as i32).to_be_bytes()); + + self.count += 1; + } + fn update_len(&mut self) { + const I32_SIZE: usize = std::mem::size_of::(); + + let size_bytes = (self.count as i32).to_be_bytes(); + + self.buf[self.len_start_index..self.len_start_index + I32_SIZE] + .copy_from_slice(&size_bytes); + } +} +impl Drop for ArrayEncoder<'_, T> +where + T: Encode, + Postgres: HasSqlType, +{ + fn drop(&mut self) { + self.update_len(); + } +} diff --git a/sqlx-core/src/postgres/types/bool.rs b/sqlx-core/src/postgres/types/bool.rs index 56e220cf1d..a1a4c8e10f 100644 --- a/sqlx-core/src/postgres/types/bool.rs +++ b/sqlx-core/src/postgres/types/bool.rs @@ -16,6 +16,11 @@ impl HasSqlType<[bool]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_BOOL) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for bool { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index 0c06b9085e..d484cc19d1 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -17,6 +17,12 @@ impl HasSqlType<[&'_ [u8]]> for Postgres { } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} + // TODO: Do we need the [HasSqlType] here on the Vec? impl HasSqlType> for Postgres { fn type_info() -> PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/chrono.rs b/sqlx-core/src/postgres/types/chrono.rs index d73d1d8927..4014f4936b 100644 --- a/sqlx-core/src/postgres/types/chrono.rs +++ b/sqlx-core/src/postgres/types/chrono.rs @@ -55,6 +55,24 @@ impl HasSqlType<[NaiveDateTime]> for Postgres { } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} + +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} + +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} + impl HasSqlType<[DateTime]> for Postgres where Tz: TimeZone, diff --git a/sqlx-core/src/postgres/types/float.rs b/sqlx-core/src/postgres/types/float.rs index 48539d989b..dff7a7edf5 100644 --- a/sqlx-core/src/postgres/types/float.rs +++ b/sqlx-core/src/postgres/types/float.rs @@ -16,6 +16,11 @@ impl HasSqlType<[f32]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_FLOAT4) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for f32 { fn encode(&self, buf: &mut Vec) { @@ -42,6 +47,11 @@ impl HasSqlType<[f64]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_FLOAT8) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for f64 { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/int.rs b/sqlx-core/src/postgres/types/int.rs index b2755883b1..156ebcbf4b 100644 --- a/sqlx-core/src/postgres/types/int.rs +++ b/sqlx-core/src/postgres/types/int.rs @@ -18,6 +18,11 @@ impl HasSqlType<[i16]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_INT2) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for i16 { fn encode(&self, buf: &mut Vec) { @@ -42,6 +47,11 @@ impl HasSqlType<[i32]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_INT4) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for i32 { fn encode(&self, buf: &mut Vec) { @@ -66,6 +76,11 @@ impl HasSqlType<[i64]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_INT8) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} impl Encode for i64 { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 28404d8b26..36dfa23fc1 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -1,3 +1,4 @@ +mod array; mod bool; mod bytes; mod float; diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 56f74ff9db..18899ee369 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -18,6 +18,11 @@ impl HasSqlType<[&'_ str]> for Postgres { PgTypeInfo::new(TypeId::ARRAY_TEXT) } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} // TODO: Do we need [HasSqlType] on String here? impl HasSqlType for Postgres { @@ -25,6 +30,16 @@ impl HasSqlType for Postgres { >::type_info() } } +impl HasSqlType<[String]> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >>::type_info() + } +} impl Encode for str { fn encode(&self, buf: &mut Vec) { diff --git a/sqlx-core/src/postgres/types/uuid.rs b/sqlx-core/src/postgres/types/uuid.rs index fecc408cdd..8c65fe37e0 100644 --- a/sqlx-core/src/postgres/types/uuid.rs +++ b/sqlx-core/src/postgres/types/uuid.rs @@ -19,6 +19,12 @@ impl HasSqlType<[Uuid]> for Postgres { } } +impl HasSqlType> for Postgres { + fn type_info() -> PgTypeInfo { + >::type_info() + } +} + impl Encode for Uuid { fn encode(&self, buf: &mut Vec) { buf.extend_from_slice(self.as_bytes()); diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index dcace458bb..6ea3906771 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -25,6 +25,15 @@ impl_database_ext! { #[cfg(feature = "chrono")] sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, + + // Arrays + + Vec | &[String], + Vec | &[i16], + Vec | &[i32], + Vec | &[i64], + Vec | &[f32], + Vec | &[f64], }, ParamChecking::Strong } diff --git a/tests/postgres-macros.rs b/tests/postgres-macros.rs index 82378a5abe..a90374325f 100644 --- a/tests/postgres-macros.rs +++ b/tests/postgres-macros.rs @@ -176,6 +176,32 @@ async fn test_many_args() -> anyhow::Result<()> { Ok(()) } +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_array_from_slice() -> anyhow::Result<()> { + let mut conn = connect().await?; + + let list: &[i32] = &[1, 2, 3, 4i32]; + + let result = sqlx::query!("SELECT $1::int[] as my_array", list) + .fetch_one(&mut conn) + .await?; + + assert_eq!(result.my_array, vec![1, 2, 3, 4]); + + println!("result ID: {:?}", result.my_array); + + let account = sqlx::query!("SELECT ARRAY[4,3,2,1] as my_array") + .fetch_one(&mut conn) + .await?; + + assert_eq!(account.my_array, vec![4, 3, 2, 1]); + + println!("account ID: {:?}", account.my_array); + + Ok(()) +} + async fn connect() -> anyhow::Result { let _ = dotenv::dotenv(); let _ = env_logger::try_init(); diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 177c8e0ae9..1d246ed4cb 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -37,6 +37,12 @@ test!(postgres_double: f64: "939399419.1225182::double precision" == 939399419.1 test!(postgres_text: String: "'this is foo'" == "this is foo", "''" == ""); +test!(postgres_int_vec: Vec: "ARRAY[1, 2, 3]::int[]" == vec![1, 2, 3i32], "ARRAY[3, 292, 15, 2, 3]::int[]" == vec![3, 292, 15, 2, 3], "ARRAY[7, 6, 5, 4, 3, 2, 1]::int[]" == vec![7, 6, 5, 4, 3, 2, 1], "ARRAY[]::int[]" == vec![] as Vec); +test!(postgres_string_vec: Vec: "ARRAY['Hello', 'world', 'friend']::text[]" == vec!["Hello", "world", "friend"]); +test!(postgres_bool_vec: Vec: "ARRAY[true, true, false, true]::bool[]" == vec![true, true, false, true]); +test!(postgres_real_vec: Vec: "ARRAY[0.0, 1.0, 3.14, 1.234, -0.002, 100000.0]::real[]" == vec![0.0, 1.0, 3.14, 1.234, -0.002, 100000.0_f32]); +test!(postgres_double_vec: Vec: "ARRAY[0.0, 1.0, 3.14, 1.234, -0.002, 100000.0]::double precision[]" == vec![0.0, 1.0, 3.14, 1.234, -0.002, 100000.0_f64]); + #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn postgres_bytes() -> anyhow::Result<()> {