Skip to content

Commit

Permalink
fix: partial support flatten enum in struct
Browse files Browse the repository at this point in the history
Close #14.

I adapted the implementation from [ron-rs/ron#451](ron-rs/ron#451).

This is a workaround for Serde's internal buffer type used when deserializing via `visit_enum`.
  • Loading branch information
greenhat616 committed Oct 21, 2024
1 parent b77a147 commit 35bb238
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 18 deletions.
111 changes: 96 additions & 15 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl<'de> Deserializer<'de> {
path: Path::Root,
remaining_depth: 128,
current_enum: None,
is_serde_content_newtype: false,
})?;
if let Some(parse_error) = document.error {
return Err(error::shared(parse_error));
Expand All @@ -130,6 +131,7 @@ impl<'de> Deserializer<'de> {
path: Path::Root,
remaining_depth: 128,
current_enum: None,
is_serde_content_newtype: false,
})?;
if let Some(parse_error) = document.error {
return Err(error::shared(parse_error));
Expand Down Expand Up @@ -434,6 +436,7 @@ struct DeserializerFromEvents<'de, 'document> {
path: Path<'document>,
remaining_depth: u8,
current_enum: Option<CurrentEnum<'document>>,
is_serde_content_newtype: bool,
}

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -487,6 +490,7 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> {
path: Path::Alias { parent: &self.path },
remaining_depth: self.remaining_depth,
current_enum: None,
is_serde_content_newtype: self.is_serde_content_newtype,
})
}
None => panic!("unresolved alias: {}", *pos),
Expand Down Expand Up @@ -672,6 +676,7 @@ impl<'de> de::SeqAccess<'de> for SeqAccess<'de, '_, '_> {
},
remaining_depth: self.de.remaining_depth,
current_enum: None,
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
self.len += 1;
seed.deserialize(&mut element_de).map(Some)
Expand Down Expand Up @@ -732,6 +737,7 @@ impl<'de> de::MapAccess<'de> for MapAccess<'de, '_, '_> {
},
remaining_depth: self.de.remaining_depth,
current_enum: None,
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
seed.deserialize(&mut value_de)
}
Expand All @@ -741,6 +747,8 @@ struct EnumAccess<'de, 'document, 'variant> {
de: &'variant mut DeserializerFromEvents<'de, 'document>,
name: Option<&'static str>,
tag: &'document str,
/// a flag to do a hack to run visitor.visit_map() instead of visitor.visit_enum() when is_serde_content is true
has_visited: bool,
}

impl<'de, 'variant> de::EnumAccess<'de> for EnumAccess<'de, '_, 'variant> {
Expand All @@ -763,11 +771,51 @@ impl<'de, 'variant> de::EnumAccess<'de> for EnumAccess<'de, '_, 'variant> {
name: self.name,
tag: self.tag,
}),
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
Ok((variant, visitor))
}
}

impl<'de> de::MapAccess<'de> for EnumAccess<'de, '_, '_> {
type Error = Error;

fn next_key_seed<K>(&mut self, seed: K) -> std::result::Result<Option<K::Value>, Self::Error>
where
K: DeserializeSeed<'de>,
{
if self.has_visited {
return Ok(None);
}
self.has_visited = true;
let str_de = StrDeserializer::<Error>::new(self.tag);
let variant = seed.deserialize(str_de)?;
Ok(Some(variant))
}

fn next_value_seed<V>(&mut self, seed: V) -> std::result::Result<V::Value, Self::Error>
where
V: DeserializeSeed<'de>,
{
let old_serde_content_newtype = mem::replace(&mut self.de.is_serde_content_newtype, true);
let mut visitor = DeserializerFromEvents {
document: self.de.document,
pos: self.de.pos,
jumpcount: self.de.jumpcount,
path: self.de.path,
remaining_depth: self.de.remaining_depth,
current_enum: Some(CurrentEnum {
name: self.name,
tag: self.tag,
}),
is_serde_content_newtype: self.de.is_serde_content_newtype,
};
let result = seed.deserialize(&mut visitor)?;
self.de.is_serde_content_newtype = old_serde_content_newtype;
Ok(result)
}
}

impl<'de> de::VariantAccess<'de> for DeserializerFromEvents<'de, '_> {
type Error = Error;

Expand Down Expand Up @@ -1220,39 +1268,69 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> {
}
parse_tag(tag)
}
// TODO: switch to JSON enum semantics for JSON content
// Robust impl blocked on https://github.com/serde-rs/serde/pull/2420
let is_serde_content =
std::any::type_name::<V::Value>() == "serde::__private::de::content::Content";

let old_serde_content_newtype = mem::replace(&mut self.is_serde_content_newtype, false);
loop {
match next {
Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_any(visitor),
Event::Scalar(scalar) => {
if let Some(tag) = enum_tag(scalar.tag.as_ref(), tagged_already) {
*self.pos -= 1;
break visitor.visit_enum(EnumAccess {
de: self,
name: None,
tag,
});
break {
let access = EnumAccess {
de: self,
name: None,
tag,
has_visited: false,
};
if is_serde_content || old_serde_content_newtype {
visitor.visit_map(access)
} else {
visitor.visit_enum(access)
}
};
}
break visit_scalar(visitor, scalar, tagged_already);
}
Event::SequenceStart(sequence) => {
if let Some(tag) = enum_tag(sequence.tag.as_ref(), tagged_already) {
*self.pos -= 1;
break visitor.visit_enum(EnumAccess {
de: self,
name: None,
tag,
});
break {
let access = EnumAccess {
de: self,
name: None,
tag,
has_visited: false,
};
if is_serde_content || old_serde_content_newtype {
visitor.visit_map(access)
} else {
visitor.visit_enum(access)
}
};
}
break self.visit_sequence(visitor, mark);
}
Event::MappingStart(mapping) => {
if let Some(tag) = enum_tag(mapping.tag.as_ref(), tagged_already) {
*self.pos -= 1;
break visitor.visit_enum(EnumAccess {
de: self,
name: None,
tag,
});
break {
let access = EnumAccess {
de: self,
name: None,
tag,
has_visited: false,
};
if is_serde_content || old_serde_content_newtype {
visitor.visit_map(access)
} else {
visitor.visit_enum(access)
}
};
}
break self.visit_mapping(visitor, mark);
}
Expand Down Expand Up @@ -1747,6 +1825,7 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> {
de: self,
name: Some(name),
tag,
has_visited: false,
});
}
visitor.visit_enum(UnitVariantAccess { de: self })
Expand All @@ -1757,6 +1836,7 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> {
de: self,
name: Some(name),
tag,
has_visited: false,
});
}
let err =
Expand All @@ -1769,6 +1849,7 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> {
de: self,
name: Some(name),
tag,
has_visited: false,
});
}
let err =
Expand Down
22 changes: 20 additions & 2 deletions src/value/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use std::fmt;
use std::slice;
use std::vec;

use super::tagged::TaggedValueMapAccess;

impl<'de> Deserialize<'de> for Value {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand Down Expand Up @@ -199,14 +201,22 @@ impl<'de> Deserializer<'de> for Value {
where
V: Visitor<'de>,
{
let is_serde_value =
std::any::type_name::<V::Value>() == "serde::__private::de::content::Content";
match self {
Value::Null => visitor.visit_unit(),
Value::Bool(v) => visitor.visit_bool(v),
Value::Number(n) => n.deserialize_any(visitor),
Value::String(v) => visitor.visit_string(v),
Value::Sequence(v) => visit_sequence(v, visitor),
Value::Mapping(v) => visit_mapping(v, visitor),
Value::Tagged(tagged) => visitor.visit_enum(*tagged),
Value::Tagged(tagged) => {
if is_serde_value {
visitor.visit_map(TaggedValueMapAccess::from(*tagged))
} else {
visitor.visit_enum(*tagged)
}
}
}
}

Expand Down Expand Up @@ -716,14 +726,22 @@ impl<'de> Deserializer<'de> for &'de Value {
where
V: Visitor<'de>,
{
let is_serde_content =
std::any::type_name::<V::Value>() == "serde::__private::de::content::Content";
match self {
Value::Null => visitor.visit_unit(),
Value::Bool(v) => visitor.visit_bool(*v),
Value::Number(n) => n.deserialize_any(visitor),
Value::String(v) => visitor.visit_borrowed_str(v),
Value::Sequence(v) => visit_sequence_ref(v, visitor),
Value::Mapping(v) => visit_mapping_ref(v, visitor),
Value::Tagged(tagged) => visitor.visit_enum(&**tagged),
Value::Tagged(tagged) => {
if is_serde_content {
visitor.visit_map(TaggedValueMapAccess::from((**tagged).clone()))
} else {
visitor.visit_enum(&**tagged)
}
}
}
}

Expand Down
35 changes: 34 additions & 1 deletion src/value/tagged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::value::Value;
use crate::Error;
use serde::de::value::{BorrowedStrDeserializer, StrDeserializer};
use serde::de::{
Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, VariantAccess, Visitor,
Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, MapAccess, VariantAccess,
Visitor,
};
use serde::forward_to_deserialize_any;
use serde::ser::{Serialize, SerializeMap, Serializer};
Expand Down Expand Up @@ -260,6 +261,38 @@ impl<'de> EnumAccess<'de> for TaggedValue {
}
}

pub struct TaggedValueMapAccess {
inner: Option<TaggedValue>,
}

impl From<TaggedValue> for TaggedValueMapAccess {
fn from(inner: TaggedValue) -> Self {
TaggedValueMapAccess { inner: Some(inner) }
}
}

impl<'de> MapAccess<'de> for TaggedValueMapAccess {
type Error = Error;

fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Error>
where
K: DeserializeSeed<'de>,
{
if self.inner.is_none() {
return Ok(None);
}
let tag = StrDeserializer::<Error>::new(nobang(&self.inner.as_ref().unwrap().tag.string));
seed.deserialize(tag).map(Some)
}

fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
where
V: DeserializeSeed<'de>,
{
seed.deserialize(self.inner.take().unwrap().value.clone())
}
}

impl<'de> VariantAccess<'de> for Value {
type Error = Error;

Expand Down
44 changes: 44 additions & 0 deletions tests/test_de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ where
assert_eq!(*expected, deserialized);

let value: Value = serde_yaml_ng::from_str(yaml).unwrap();
dbg!("{:?}", &value);
let deserialized = T::deserialize(&value).unwrap();
assert_eq!(*expected, deserialized);

Expand Down Expand Up @@ -242,6 +243,49 @@ fn test_enum_representations() {
test_de_no_value(yaml, &expected);
}

#[test]
fn test_enum_outer_flatten() {
#[derive(Deserialize, PartialEq, Debug)]
enum Enum {
A,
B,
C(String),
}
#[derive(Deserialize, PartialEq, Debug)]
struct Inner {
a: Enum,
}
#[derive(Deserialize, PartialEq, Debug)]
struct Outer {
#[serde(flatten)]
inner: Inner,
}

let yaml: &str = indoc! {"
a: !C x
"};
let expected = Outer {
inner: Inner {
a: Enum::C("x".to_owned()),
},
};
test_de(yaml, &expected);
let yaml: &str = indoc! {"
a: !C
"};
// let expected = Outer {
// inner: Inner {
// a: Enum::C(String::new()),
// },
// };
// test_de_no_value(yaml, &expected);
// This should fail. Blocked by by https://github.com/serde-rs/serde/issues/1183
assert!(serde_yaml_ng::from_str::<Outer>(yaml).is_err());

serde_yaml_ng::from_str::<serde_yaml_ng::Value>(yaml).unwrap();
serde_yaml_ng::from_str::<serde::de::IgnoredAny>(yaml).unwrap();
}

#[test]
fn test_number_as_string() {
#[derive(Deserialize, PartialEq, Debug)]
Expand Down

0 comments on commit 35bb238

Please sign in to comment.