From e1f53b2d4b5298763f207af2f32466a45ca14d24 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Tue, 26 Dec 2023 17:45:17 -0800 Subject: [PATCH] Support deserializing &'de [u8; N] --- src/bytearray.rs | 41 +++++++++++++++++++++++++++++++++++++++++ src/de.rs | 10 ++++++++++ tests/test_derive.rs | 8 +++++++- 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/bytearray.rs b/src/bytearray.rs index 216b000..0d97d23 100644 --- a/src/bytearray.rs +++ b/src/bytearray.rs @@ -35,6 +35,7 @@ use serde::ser::{Serialize, Serializer}; /// # } /// ``` #[derive(Copy, Clone, Eq, Ord)] +#[cfg_attr(not(doc), repr(transparent))] pub struct ByteArray { bytes: [u8; N], } @@ -49,6 +50,10 @@ impl ByteArray { pub fn into_array(self) -> [u8; N] { self.bytes } + + fn from_ref(bytes: &[u8; N]) -> &Self { + unsafe { &*(bytes as *const [u8; N] as *const ByteArray) } + } } impl Debug for ByteArray { @@ -218,3 +223,39 @@ impl<'de, const N: usize> Deserialize<'de> for ByteArray { deserializer.deserialize_bytes(ByteArrayVisitor::) } } + +struct BorrowedByteArrayVisitor; + +impl<'de, const N: usize> Visitor<'de> for BorrowedByteArrayVisitor { + type Value = &'de ByteArray; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a borrowed byte array of length {}", N) + } + + fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result + where + E: Error, + { + let borrowed_byte_array: &'de [u8; N] = v + .try_into() + .map_err(|_| E::invalid_length(v.len(), &self))?; + Ok(ByteArray::from_ref(borrowed_byte_array)) + } + + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: Error, + { + self.visit_borrowed_bytes(v.as_bytes()) + } +} + +impl<'a, 'de: 'a, const N: usize> Deserialize<'de> for &'a ByteArray { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(BorrowedByteArrayVisitor::) + } +} diff --git a/src/de.rs b/src/de.rs index 5933264..179c961 100644 --- a/src/de.rs +++ b/src/de.rs @@ -83,6 +83,16 @@ impl<'de, const N: usize> Deserialize<'de> for ByteArray { } } +impl<'de: 'a, 'a, const N: usize> Deserialize<'de> for &'a ByteArray { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Via the serde::Deserialize impl for &ByteArray. + serde::Deserialize::deserialize(deserializer) + } +} + #[cfg(any(feature = "std", feature = "alloc"))] impl<'de> Deserialize<'de> for ByteBuf { fn deserialize(deserializer: D) -> Result diff --git a/tests/test_derive.rs b/tests/test_derive.rs index dd818e4..57616e7 100644 --- a/tests/test_derive.rs +++ b/tests/test_derive.rs @@ -22,6 +22,9 @@ struct Test<'a> { #[serde(with = "serde_bytes")] byte_array: ByteArray<314>, + #[serde(with = "serde_bytes")] + borrowed_byte_array: &'a ByteArray<314>, + #[serde(with = "serde_bytes")] byte_buf: ByteBuf, @@ -67,6 +70,7 @@ fn test() { vec: b"...".to_vec(), bytes: Bytes::new(b"..."), byte_array: ByteArray::new([0; 314]), + borrowed_byte_array: &ByteArray::new([0; 314]), byte_buf: ByteBuf::from(b"...".as_ref()), cow_slice: Cow::Borrowed(b"..."), cow_bytes: Cow::Borrowed(Bytes::new(b"...")), @@ -84,7 +88,7 @@ fn test() { &[ Token::Struct { name: "Test", - len: 15, + len: 16, }, Token::Str("slice"), Token::BorrowedBytes(b"..."), @@ -96,6 +100,8 @@ fn test() { Token::BorrowedBytes(b"..."), Token::Str("byte_array"), Token::Bytes(&[0; 314]), + Token::Str("borrowed_byte_array"), + Token::BorrowedBytes(&[0; 314]), Token::Str("byte_buf"), Token::Bytes(b"..."), Token::Str("cow_slice"),