diff --git a/src/de.rs b/src/de.rs index 12e8516..d45b93b 100644 --- a/src/de.rs +++ b/src/de.rs @@ -109,6 +109,7 @@ impl<'de> Deserializer<'de> { path: Path::Root, remaining_depth: 128, current_enum: None, + is_serde_content_newtype: false, })?; if let Some(parse_error) = document.error { return Err(error::shared(parse_error)); @@ -130,6 +131,7 @@ impl<'de> Deserializer<'de> { path: Path::Root, remaining_depth: 128, current_enum: None, + is_serde_content_newtype: false, })?; if let Some(parse_error) = document.error { return Err(error::shared(parse_error)); @@ -434,6 +436,7 @@ struct DeserializerFromEvents<'de, 'document> { path: Path<'document>, remaining_depth: u8, current_enum: Option>, + is_serde_content_newtype: bool, } #[derive(Copy, Clone)] @@ -487,6 +490,7 @@ impl<'de, 'document> DeserializerFromEvents<'de, 'document> { path: Path::Alias { parent: &self.path }, remaining_depth: self.remaining_depth, current_enum: None, + is_serde_content_newtype: self.is_serde_content_newtype, }) } None => panic!("unresolved alias: {}", *pos), @@ -672,6 +676,7 @@ impl<'de> de::SeqAccess<'de> for SeqAccess<'de, '_, '_> { }, remaining_depth: self.de.remaining_depth, current_enum: None, + is_serde_content_newtype: self.de.is_serde_content_newtype, }; self.len += 1; seed.deserialize(&mut element_de).map(Some) @@ -732,6 +737,7 @@ impl<'de> de::MapAccess<'de> for MapAccess<'de, '_, '_> { }, remaining_depth: self.de.remaining_depth, current_enum: None, + is_serde_content_newtype: self.de.is_serde_content_newtype, }; seed.deserialize(&mut value_de) } @@ -741,6 +747,8 @@ struct EnumAccess<'de, 'document, 'variant> { de: &'variant mut DeserializerFromEvents<'de, 'document>, name: Option<&'static str>, tag: &'document str, + /// a flag to do a hack to run visitor.visit_map() instead of visitor.visit_enum() when is_serde_content is true + has_visited: bool, } impl<'de, 'variant> de::EnumAccess<'de> for EnumAccess<'de, '_, 'variant> { @@ -763,11 +771,51 @@ impl<'de, 'variant> de::EnumAccess<'de> for EnumAccess<'de, '_, 'variant> { name: self.name, tag: self.tag, }), + is_serde_content_newtype: self.de.is_serde_content_newtype, }; Ok((variant, visitor)) } } +impl<'de> de::MapAccess<'de> for EnumAccess<'de, '_, '_> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> std::result::Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if self.has_visited { + return Ok(None); + } + self.has_visited = true; + let str_de = StrDeserializer::::new(self.tag); + let variant = seed.deserialize(str_de)?; + Ok(Some(variant)) + } + + fn next_value_seed(&mut self, seed: V) -> std::result::Result + where + V: DeserializeSeed<'de>, + { + let old_serde_content_newtype = mem::replace(&mut self.de.is_serde_content_newtype, true); + let mut visitor = DeserializerFromEvents { + document: self.de.document, + pos: self.de.pos, + jumpcount: self.de.jumpcount, + path: self.de.path, + remaining_depth: self.de.remaining_depth, + current_enum: Some(CurrentEnum { + name: self.name, + tag: self.tag, + }), + is_serde_content_newtype: self.de.is_serde_content_newtype, + }; + let result = seed.deserialize(&mut visitor)?; + self.de.is_serde_content_newtype = old_serde_content_newtype; + Ok(result) + } +} + impl<'de> de::VariantAccess<'de> for DeserializerFromEvents<'de, '_> { type Error = Error; @@ -1220,39 +1268,69 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> { } parse_tag(tag) } + // TODO: switch to JSON enum semantics for JSON content + // Robust impl blocked on https://github.com/serde-rs/serde/pull/2420 + let is_serde_content = + std::any::type_name::() == "serde::__private::de::content::Content"; + + let old_serde_content_newtype = mem::replace(&mut self.is_serde_content_newtype, false); loop { match next { Event::Alias(mut pos) => break self.jump(&mut pos)?.deserialize_any(visitor), Event::Scalar(scalar) => { if let Some(tag) = enum_tag(scalar.tag.as_ref(), tagged_already) { *self.pos -= 1; - break visitor.visit_enum(EnumAccess { - de: self, - name: None, - tag, - }); + break { + let access = EnumAccess { + de: self, + name: None, + tag, + has_visited: false, + }; + if is_serde_content || old_serde_content_newtype { + visitor.visit_map(access) + } else { + visitor.visit_enum(access) + } + }; } break visit_scalar(visitor, scalar, tagged_already); } Event::SequenceStart(sequence) => { if let Some(tag) = enum_tag(sequence.tag.as_ref(), tagged_already) { *self.pos -= 1; - break visitor.visit_enum(EnumAccess { - de: self, - name: None, - tag, - }); + break { + let access = EnumAccess { + de: self, + name: None, + tag, + has_visited: false, + }; + if is_serde_content || old_serde_content_newtype { + visitor.visit_map(access) + } else { + visitor.visit_enum(access) + } + }; } break self.visit_sequence(visitor, mark); } Event::MappingStart(mapping) => { if let Some(tag) = enum_tag(mapping.tag.as_ref(), tagged_already) { *self.pos -= 1; - break visitor.visit_enum(EnumAccess { - de: self, - name: None, - tag, - }); + break { + let access = EnumAccess { + de: self, + name: None, + tag, + has_visited: false, + }; + if is_serde_content || old_serde_content_newtype { + visitor.visit_map(access) + } else { + visitor.visit_enum(access) + } + }; } break self.visit_mapping(visitor, mark); } @@ -1747,6 +1825,7 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> { de: self, name: Some(name), tag, + has_visited: false, }); } visitor.visit_enum(UnitVariantAccess { de: self }) @@ -1757,6 +1836,7 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> { de: self, name: Some(name), tag, + has_visited: false, }); } let err = @@ -1769,6 +1849,7 @@ impl<'de> de::Deserializer<'de> for &mut DeserializerFromEvents<'de, '_> { de: self, name: Some(name), tag, + has_visited: false, }); } let err = diff --git a/src/value/de.rs b/src/value/de.rs index 7b56c18..70af8b2 100644 --- a/src/value/de.rs +++ b/src/value/de.rs @@ -11,6 +11,8 @@ use std::fmt; use std::slice; use std::vec; +use super::tagged::TaggedValueMapAccess; + impl<'de> Deserialize<'de> for Value { fn deserialize(deserializer: D) -> Result where @@ -199,6 +201,8 @@ impl<'de> Deserializer<'de> for Value { where V: Visitor<'de>, { + let is_serde_value = + std::any::type_name::() == "serde::__private::de::content::Content"; match self { Value::Null => visitor.visit_unit(), Value::Bool(v) => visitor.visit_bool(v), @@ -206,7 +210,13 @@ impl<'de> Deserializer<'de> for Value { Value::String(v) => visitor.visit_string(v), Value::Sequence(v) => visit_sequence(v, visitor), Value::Mapping(v) => visit_mapping(v, visitor), - Value::Tagged(tagged) => visitor.visit_enum(*tagged), + Value::Tagged(tagged) => { + if is_serde_value { + visitor.visit_map(TaggedValueMapAccess::from(*tagged)) + } else { + visitor.visit_enum(*tagged) + } + } } } @@ -716,6 +726,8 @@ impl<'de> Deserializer<'de> for &'de Value { where V: Visitor<'de>, { + let is_serde_content = + std::any::type_name::() == "serde::__private::de::content::Content"; match self { Value::Null => visitor.visit_unit(), Value::Bool(v) => visitor.visit_bool(*v), @@ -723,7 +735,13 @@ impl<'de> Deserializer<'de> for &'de Value { Value::String(v) => visitor.visit_borrowed_str(v), Value::Sequence(v) => visit_sequence_ref(v, visitor), Value::Mapping(v) => visit_mapping_ref(v, visitor), - Value::Tagged(tagged) => visitor.visit_enum(&**tagged), + Value::Tagged(tagged) => { + if is_serde_content { + visitor.visit_map(TaggedValueMapAccess::from((**tagged).clone())) + } else { + visitor.visit_enum(&**tagged) + } + } } } diff --git a/src/value/tagged.rs b/src/value/tagged.rs index d83f053..8ac6009 100644 --- a/src/value/tagged.rs +++ b/src/value/tagged.rs @@ -3,7 +3,8 @@ use crate::value::Value; use crate::Error; use serde::de::value::{BorrowedStrDeserializer, StrDeserializer}; use serde::de::{ - Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, VariantAccess, Visitor, + Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, MapAccess, VariantAccess, + Visitor, }; use serde::forward_to_deserialize_any; use serde::ser::{Serialize, SerializeMap, Serializer}; @@ -260,6 +261,38 @@ impl<'de> EnumAccess<'de> for TaggedValue { } } +pub struct TaggedValueMapAccess { + inner: Option, +} + +impl From for TaggedValueMapAccess { + fn from(inner: TaggedValue) -> Self { + TaggedValueMapAccess { inner: Some(inner) } + } +} + +impl<'de> MapAccess<'de> for TaggedValueMapAccess { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Error> + where + K: DeserializeSeed<'de>, + { + if self.inner.is_none() { + return Ok(None); + } + let tag = StrDeserializer::::new(nobang(&self.inner.as_ref().unwrap().tag.string)); + seed.deserialize(tag).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(self.inner.take().unwrap().value.clone()) + } +} + impl<'de> VariantAccess<'de> for Value { type Error = Error; diff --git a/tests/test_de.rs b/tests/test_de.rs index 774db07..bb30cd3 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -20,6 +20,7 @@ where assert_eq!(*expected, deserialized); let value: Value = serde_yaml_ng::from_str(yaml).unwrap(); + dbg!("{:?}", &value); let deserialized = T::deserialize(&value).unwrap(); assert_eq!(*expected, deserialized); @@ -242,6 +243,49 @@ fn test_enum_representations() { test_de_no_value(yaml, &expected); } +#[test] +fn test_enum_outer_flatten() { + #[derive(Deserialize, PartialEq, Debug)] + enum Enum { + A, + B, + C(String), + } + #[derive(Deserialize, PartialEq, Debug)] + struct Inner { + a: Enum, + } + #[derive(Deserialize, PartialEq, Debug)] + struct Outer { + #[serde(flatten)] + inner: Inner, + } + + let yaml: &str = indoc! {" + a: !C x + "}; + let expected = Outer { + inner: Inner { + a: Enum::C("x".to_owned()), + }, + }; + test_de(yaml, &expected); + let yaml: &str = indoc! {" + a: !C + "}; + // let expected = Outer { + // inner: Inner { + // a: Enum::C(String::new()), + // }, + // }; + // test_de_no_value(yaml, &expected); + // This should fail. Blocked by by https://github.com/serde-rs/serde/issues/1183 + assert!(serde_yaml_ng::from_str::(yaml).is_err()); + + serde_yaml_ng::from_str::(yaml).unwrap(); + serde_yaml_ng::from_str::(yaml).unwrap(); +} + #[test] fn test_number_as_string() { #[derive(Deserialize, PartialEq, Debug)]