Skip to content

Commit

Permalink
fix: partial support flatten enum in struct
Browse files Browse the repository at this point in the history
Close #14.

I adapted the implementation from [ron-rs/ron#451](ron-rs/ron#451).

This is a workaround for Serde's internal buffer type used when deserializing via `visit_enum`.
  • Loading branch information
greenhat616 committed Oct 21, 2024
1 parent 2b30c30 commit 4037163
Show file tree
Hide file tree
Showing 18 changed files with 247 additions and 73 deletions.
147 changes: 113 additions & 34 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl<'de> Deserializer<'de> {
path: Path::Root,
remaining_depth: 128,
current_enum: None,
is_serde_content_newtype: false,
})?;
if let Some(parse_error) = document.error {
return Err(error::shared(parse_error));
Expand All @@ -130,6 +131,7 @@ impl<'de> Deserializer<'de> {
path: Path::Root,
remaining_depth: 128,
current_enum: None,
is_serde_content_newtype: false,
})?;
if let Some(parse_error) = document.error {
return Err(error::shared(parse_error));
Expand All @@ -142,7 +144,7 @@ impl<'de> Deserializer<'de> {
}
}

impl<'de> Iterator for Deserializer<'de> {
impl Iterator for Deserializer<'_> {
type Item = Self;

fn next(&mut self) -> Option<Self> {
Expand Down Expand Up @@ -434,6 +436,7 @@ struct DeserializerFromEvents<'de, 'document> {
path: Path<'document>,
remaining_depth: u8,
current_enum: Option<CurrentEnum<'document>>,
is_serde_content_newtype: bool,
}

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -487,6 +490,7 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> {
path: Path::Alias { parent: &self.path },
remaining_depth: self.remaining_depth,
current_enum: None,
is_serde_content_newtype: self.is_serde_content_newtype,
})
}
None => panic!("unresolved alias: {}", *pos),
Expand Down Expand Up @@ -649,7 +653,7 @@ struct SeqAccess<'de, 'document, 'seq> {
len: usize,
}

impl<'de, 'document, 'seq> de::SeqAccess<'de> for SeqAccess<'de, 'document, 'seq> {
impl<'de> de::SeqAccess<'de> for SeqAccess<'de, '_, '_> {
type Error = Error;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
Expand All @@ -672,6 +676,7 @@ impl<'de, 'document, 'seq> de::SeqAccess<'de> for SeqAccess<'de, 'document, 'seq
},
remaining_depth: self.de.remaining_depth,
current_enum: None,
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
self.len += 1;
seed.deserialize(&mut element_de).map(Some)
Expand All @@ -687,7 +692,7 @@ struct MapAccess<'de, 'document, 'map> {
key: Option<&'document [u8]>,
}

impl<'de, 'document, 'map> de::MapAccess<'de> for MapAccess<'de, 'document, 'map> {
impl<'de> de::MapAccess<'de> for MapAccess<'de, '_, '_> {
type Error = Error;

fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
Expand Down Expand Up @@ -732,6 +737,7 @@ impl<'de, 'document, 'map> de::MapAccess<'de> for MapAccess<'de, 'document, 'map
},
remaining_depth: self.de.remaining_depth,
current_enum: None,
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
seed.deserialize(&mut value_de)
}
Expand All @@ -741,9 +747,11 @@ struct EnumAccess<'de, 'document, 'variant> {
de: &'variant mut DeserializerFromEvents<'de, 'document>,
name: Option<&'static str>,
tag: &'document str,
/// a flag to do a hack to run visitor.visit_map() instead of visitor.visit_enum() when is_serde_content is true
has_visited: bool,
}

impl<'de, 'document, 'variant> de::EnumAccess<'de> for EnumAccess<'de, 'document, 'variant> {
impl<'de, 'variant> de::EnumAccess<'de> for EnumAccess<'de, '_, 'variant> {
type Error = Error;
type Variant = DeserializerFromEvents<'de, 'variant>;

Expand All @@ -763,12 +771,52 @@ impl<'de, 'document, 'variant> de::EnumAccess<'de> for EnumAccess<'de, 'document
name: self.name,
tag: self.tag,
}),
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
Ok((variant, visitor))
}
}

impl<'de, 'document> de::VariantAccess<'de> for DeserializerFromEvents<'de, 'document> {
impl<'de> de::MapAccess<'de> for EnumAccess<'de, '_, '_> {
type Error = Error;

fn next_key_seed<K>(&mut self, seed: K) -> std::result::Result<Option<K::Value>, Self::Error>
where
K: DeserializeSeed<'de>,
{
if self.has_visited {
return Ok(None);
}
self.has_visited = true;
let str_de = StrDeserializer::<Error>::new(self.tag);
let variant = seed.deserialize(str_de)?;
Ok(Some(variant))
}

fn next_value_seed<V>(&mut self, seed: V) -> std::result::Result<V::Value, Self::Error>
where
V: DeserializeSeed<'de>,
{
let old_serde_content_newtype = mem::replace(&mut self.de.is_serde_content_newtype, true);
let mut visitor = DeserializerFromEvents {
document: self.de.document,
pos: self.de.pos,
jumpcount: self.de.jumpcount,
path: self.de.path,
remaining_depth: self.de.remaining_depth,
current_enum: Some(CurrentEnum {
name: self.name,
tag: self.tag,
}),
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
let result = seed.deserialize(&mut visitor)?;
self.de.is_serde_content_newtype = old_serde_content_newtype;
Ok(result)
}
}

impl<'de> de::VariantAccess<'de> for DeserializerFromEvents<'de, '_> {
type Error = Error;

fn unit_variant(mut self) -> Result<()> {
Expand Down Expand Up @@ -801,7 +849,7 @@ struct UnitVariantAccess<'de, 'document, 'variant> {
de: &'variant mut DeserializerFromEvents<'de, 'document>,
}

impl<'de, 'document, 'variant> de::EnumAccess<'de> for UnitVariantAccess<'de, 'document, 'variant> {
impl<'de> de::EnumAccess<'de> for UnitVariantAccess<'de, '_, '_> {
type Error = Error;
type Variant = Self;

Expand All @@ -813,9 +861,7 @@ impl<'de, 'document, 'variant> de::EnumAccess<'de> for UnitVariantAccess<'de, 'd
}
}

impl<'de, 'document, 'variant> de::VariantAccess<'de>
for UnitVariantAccess<'de, 'document, 'variant>
{
impl<'de> de::VariantAccess<'de> for UnitVariantAccess<'de, '_, '_> {
type Error = Error;

fn unit_variant(self) -> Result<()> {
Expand Down Expand Up @@ -1163,7 +1209,7 @@ fn invalid_type(event: &Event, exp: &dyn Expected) -> Error {
exp: &'a dyn Expected,
}

impl<'de, 'a> Visitor<'de> for InvalidType<'a> {
impl Visitor<'_> for InvalidType<'_> {
type Value = Void;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -1195,7 +1241,7 @@ fn invalid_type(event: &Event, exp: &dyn Expected) -> Error {
}
}

fn parse_tag(libyaml_tag: &Option<Tag>) -> Option<&str> {
fn parse_tag(libyaml_tag: Option<&Tag>) -> Option<&str> {
let mut bytes: &[u8] = libyaml_tag.as_ref()?;
if let (b'!', rest) = bytes.split_first()? {
if !rest.is_empty() {
Expand All @@ -1207,7 +1253,7 @@ fn parse_tag(libyaml_tag: &Option<Tag>) -> Option<&str> {
}
}

impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, 'document> {
impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> {
type Error = Error;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
Expand All @@ -1216,45 +1262,75 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de,
{
let tagged_already = self.current_enum.is_some();
let (next, mark) = self.next_event_mark()?;
fn enum_tag(tag: &Option<Tag>, tagged_already: bool) -> Option<&str> {
fn enum_tag(tag: Option<&Tag>, tagged_already: bool) -> Option<&str> {
if tagged_already {
return None;
}
parse_tag(tag)
}
// TODO: switch to JSON enum semantics for JSON content
// Robust impl blocked on https://github.com/serde-rs/serde/pull/2420
let is_serde_content =
std::any::type_name::<V::Value>() == "serde::__private::de::content::Content";

let old_serde_content_newtype = mem::replace(&mut self.is_serde_content_newtype, false);
loop {
match next {
Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_any(visitor),
Event::Scalar(scalar) => {
if let Some(tag) = enum_tag(&scalar.tag, tagged_already) {
if let Some(tag) = enum_tag(scalar.tag.as_ref(), tagged_already) {
*self.pos -= 1;
break visitor.visit_enum(EnumAccess {
de: self,
name: None,
tag,
});
break {
let access = EnumAccess {
de: self,
name: None,
tag,
has_visited: false,
};
if is_serde_content || old_serde_content_newtype {
visitor.visit_map(access)
} else {
visitor.visit_enum(access)
}
};
}
break visit_scalar(visitor, scalar, tagged_already);
}
Event::SequenceStart(sequence) => {
if let Some(tag) = enum_tag(&sequence.tag, tagged_already) {
if let Some(tag) = enum_tag(sequence.tag.as_ref(), tagged_already) {
*self.pos -= 1;
break visitor.visit_enum(EnumAccess {
de: self,
name: None,
tag,
});
break {
let access = EnumAccess {
de: self,
name: None,
tag,
has_visited: false,
};
if is_serde_content || old_serde_content_newtype {
visitor.visit_map(access)
} else {
visitor.visit_enum(access)
}
};
}
break self.visit_sequence(visitor, mark);
}
Event::MappingStart(mapping) => {
if let Some(tag) = enum_tag(&mapping.tag, tagged_already) {
if let Some(tag) = enum_tag(mapping.tag.as_ref(), tagged_already) {
*self.pos -= 1;
break visitor.visit_enum(EnumAccess {
de: self,
name: None,
tag,
});
break {
let access = EnumAccess {
de: self,
name: None,
tag,
has_visited: false,
};
if is_serde_content || old_serde_content_newtype {
visitor.visit_map(access)
} else {
visitor.visit_enum(access)
}
};
}
break self.visit_mapping(visitor, mark);
}
Expand Down Expand Up @@ -1744,33 +1820,36 @@ impl<'de, 'document> de::Deserializer<'de> for &mut DeserializerFromEvents<'de,
.deserialize_enum(name, variants, visitor)
}
Event::Scalar(scalar) => {
if let Some(tag) = parse_tag(&scalar.tag) {
if let Some(tag) = parse_tag(scalar.tag.as_ref()) {
return visitor.visit_enum(EnumAccess {
de: self,
name: Some(name),
tag,
has_visited: false,
});
}
visitor.visit_enum(UnitVariantAccess { de: self })
}
Event::MappingStart(mapping) => {
if let Some(tag) = parse_tag(&mapping.tag) {
if let Some(tag) = parse_tag(mapping.tag.as_ref()) {
return visitor.visit_enum(EnumAccess {
de: self,
name: Some(name),
tag,
has_visited: false,
});
}
let err =
de::Error::invalid_type(Unexpected::Map, &"a YAML tag starting with '!'");
Err(error::fix_mark(err, mark, self.path))
}
Event::SequenceStart(sequence) => {
if let Some(tag) = parse_tag(&sequence.tag) {
if let Some(tag) = parse_tag(sequence.tag.as_ref()) {
return visitor.visit_enum(EnumAccess {
de: self,
name: Some(name),
tag,
has_visited: false,
});
}
let err =
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ impl ErrorImpl {
_ => {
f.write_str("Error(")?;
struct MessageNoMark<'a>(&'a ErrorImpl);
impl<'a> Display for MessageNoMark<'a> {
impl Display for MessageNoMark<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.message_no_mark(f)
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,5 @@ mod private {
impl Sealed for str {}
impl Sealed for String {}
impl Sealed for crate::Value {}
impl<'a, T> Sealed for &'a T where T: ?Sized + Sealed {}
impl<T> Sealed for &T where T: ?Sized + Sealed {}
}
8 changes: 4 additions & 4 deletions src/libyaml/cstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ pub(crate) struct CStr<'a> {
marker: PhantomData<&'a [u8]>,
}

unsafe impl<'a> Send for CStr<'a> {}
unsafe impl<'a> Sync for CStr<'a> {}
unsafe impl Send for CStr<'_> {}
unsafe impl Sync for CStr<'_> {}

impl<'a> CStr<'a> {
pub fn from_bytes_with_nul(bytes: &'static [u8]) -> Self {
Expand Down Expand Up @@ -44,7 +44,7 @@ impl<'a> CStr<'a> {
}
}

impl<'a> Display for CStr<'a> {
impl Display for CStr<'_> {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
let ptr = self.ptr.as_ptr();
let len = self.len();
Expand All @@ -53,7 +53,7 @@ impl<'a> Display for CStr<'a> {
}
}

impl<'a> Debug for CStr<'a> {
impl Debug for CStr<'_> {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
let ptr = self.ptr.as_ptr();
let len = self.len();
Expand Down
2 changes: 1 addition & 1 deletion src/libyaml/emitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ unsafe fn write_handler(data: *mut c_void, buffer: *mut u8, size: u64) -> i32 {
}
}

impl<'a> Drop for EmitterPinned<'a> {
impl Drop for EmitterPinned<'_> {
fn drop(&mut self) {
unsafe { sys::yaml_emitter_delete(&mut self.sys) }
}
Expand Down
Loading

0 comments on commit 4037163

Please sign in to comment.