diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index 761b955dd6..bd07861173 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -55,6 +55,19 @@ where } } +impl Type for [T; N] +where + T: PgHasArrayType, +{ + fn type_info() -> PgTypeInfo { + T::array_type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + T::array_compatible(ty) + } +} + impl<'q, T> Encode<'q, Postgres> for Vec where for<'a> &'a [T]: Encode<'q, Postgres>, @@ -66,6 +79,16 @@ where } } +impl<'q, T, const N: usize> Encode<'q, Postgres> for [T; N] +where + for<'a> &'a [T]: Encode<'q, Postgres>, + T: Encode<'q, Postgres>, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + self.as_slice().encode_by_ref(buf) + } +} + impl<'q, T> Encode<'q, Postgres> for &'_ [T] where T: Encode<'q, Postgres> + Type, @@ -100,6 +123,19 @@ where } } +impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N] +where + T: for<'a> Decode<'a, Postgres> + Type, +{ + fn decode(value: PgValueRef<'r>) -> Result { + // This could be done more efficiently by refactoring the Vec decoding below so that it can + // be used for arrays and Vec. + let vec: Vec = Decode::decode(value)?; + let array: [T; N] = vec.try_into().map_err(|_| "wrong number of elements")?; + Ok(array) + } +} + impl<'r, T> Decode<'r, Postgres> for Vec where T: for<'a> Decode<'a, Postgres> + Type, diff --git a/sqlx-core/src/postgres/types/bytes.rs b/sqlx-core/src/postgres/types/bytes.rs index aadc2ef3be..6325c4376d 100644 --- a/sqlx-core/src/postgres/types/bytes.rs +++ b/sqlx-core/src/postgres/types/bytes.rs @@ -24,6 +24,12 @@ impl PgHasArrayType for Vec { } } +impl PgHasArrayType for [u8; N] { + fn array_type_info() -> PgTypeInfo { + <[&[u8]] as Type>::type_info() + } +} + impl Encode<'_, Postgres> for &'_ [u8] { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend_from_slice(self); @@ -38,6 +44,12 @@ impl Encode<'_, Postgres> for Vec { } } +impl Encode<'_, Postgres> for [u8; N] { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + <&[u8] as Encode>::encode(self.as_slice(), buf) + } +} + impl<'r> Decode<'r, Postgres> for &'r [u8] { fn decode(value: PgValueRef<'r>) -> Result { match value.format() { @@ -49,18 +61,33 @@ impl<'r> Decode<'r, Postgres> for &'r [u8] { } } +fn text_hex_decode_input(value: PgValueRef<'_>) -> Result<&[u8], BoxDynError> { + // BYTEA is formatted as \x followed by hex characters + value + .as_bytes()? + .strip_prefix(b"\\x") + .ok_or("text does not start with \\x") + .map_err(Into::into) +} + impl Decode<'_, Postgres> for Vec { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => value.as_bytes()?.to_owned(), - PgValueFormat::Text => { - // BYTEA is formatted as \x followed by hex characters - let text = value - .as_bytes()? - .strip_prefix(b"\\x") - .ok_or("text does not start with \\x")?; - hex::decode(text)? - } + PgValueFormat::Text => hex::decode(text_hex_decode_input(value)?)?, }) } } + +impl Decode<'_, Postgres> for [u8; N] { + fn decode(value: PgValueRef<'_>) -> Result { + let mut bytes = [0u8; N]; + match value.format() { + PgValueFormat::Binary => { + bytes = value.as_bytes()?.try_into()?; + } + PgValueFormat::Text => hex::decode_to_slice(text_hex_decode_input(value)?, &mut bytes)?, + }; + Ok(bytes) + } +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index c804d00d3d..c6e6a4db3c 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -15,6 +15,10 @@ test_type!(null_vec>>(Postgres, "array[10,NULL,50]::int2[]" == vec![Some(10_i16), None, Some(50)], )); +test_type!(null_array<[Option; 3]>(Postgres, + "array[10,NULL,50]::int2[]" == vec![Some(10_i16), None, Some(50)], +)); + test_type!(bool(Postgres, "false::boolean" == false, "true::boolean" == true @@ -24,6 +28,10 @@ test_type!(bool_vec>(Postgres, "array[true,false,true]::bool[]" == vec![true, false, true], )); +test_type!(bool_array<[bool; 3]>(Postgres, + "array[true,false,true]::bool[]" == vec![true, false, true], +)); + test_type!(byte_vec>(Postgres, "E'\\\\xDEADBEEF'::bytea" == vec![0xDE_u8, 0xAD, 0xBE, 0xEF], @@ -41,6 +49,14 @@ test_prepared_type!(byte_slice<&[u8]>(Postgres, == &[0_u8, 0, 0, 0, 0x52][..] )); +test_type!(byte_array_empty<[u8; 0]>(Postgres, + "E'\\\\x'::bytea" == [0_u8; 0], +)); + +test_type!(byte_array<[u8; 4]>(Postgres, + "E'\\\\xDEADBEEF'::bytea" == [0xDE_u8, 0xAD, 0xBE, 0xEF], +)); + test_type!(str<&str>(Postgres, "'this is foo'" == "this is foo", "''" == "", @@ -64,6 +80,10 @@ test_type!(string_vec>(Postgres, == vec!["Hello, World", "", "Goodbye"] )); +test_type!(string_array<[String; 3]>(Postgres, + "array['one','two','three']::text[]" == ["one","two","three"], +)); + test_type!(i8( Postgres, "0::\"char\"" == 0_i8, @@ -91,6 +111,14 @@ test_type!(i32_vec>(Postgres, "'{1,3,-5}'::int[]" == vec![1_i32, 3, -5] )); +test_type!(i32_array_empty<[i32; 0]>(Postgres, + "'{}'::int[]" == [0_i32; 0], +)); + +test_type!(i32_array<[i32; 4]>(Postgres, + "'{5,10,50,100}'::int[]" == [5_i32, 10, 50, 100], +)); + test_type!(i64(Postgres, "9358295312::bigint" == 9358295312_i64)); test_type!(f32(Postgres, "9419.122::real" == 9419.122_f32)); @@ -339,12 +367,18 @@ mod json { "'[\"Hello\", \"World!\"]'::json" == json!(["Hello", "World!"]) )); - test_type!(json_array>( + test_type!(json_vec>( Postgres, "SELECT ({0}::jsonb[] is not distinct from $1::jsonb[])::int4, {0} as _2, $2 as _3", "array['\"😎\"'::json, '\"🙋‍♀️\"'::json]::json[]" == vec![json!("😎"), json!("🙋‍♀️")], )); + test_type!(json_array<[JsonValue; 2]>( + Postgres, + "SELECT ({0}::jsonb[] is not distinct from $1::jsonb[])::int4, {0} as _2, $2 as _3", + "array['\"😎\"'::json, '\"🙋‍♀️\"'::json]::json[]" == [json!("😎"), json!("🙋‍♀️")], + )); + test_type!(jsonb( Postgres, "'\"Hello, World\"'::jsonb" == json!("Hello, World"),