Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize internally tagged enums -- do not use internal buffer if tag is the first field #1922

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions serde/src/de/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
use crate::lib::*;

use self::private::{First, Second};
use crate::de::{self, size_hint, Deserializer, Expected, IntoDeserializer, SeqAccess, Visitor};
use crate::de::{
self, size_hint, Deserializer, Expected, IgnoredAny, IntoDeserializer, SeqAccess, Visitor,
};
use crate::ser;

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -978,7 +980,9 @@ where
}
}

struct ExpectedInSeq(usize);
/// Number of elements still expected in a sequence. Does not include already
/// read elements.
pub(crate) struct ExpectedInSeq(pub usize);

impl Expected for ExpectedInSeq {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -1076,9 +1080,38 @@ where
visitor.visit_seq(self.seq)
}

/// Deserializes a unit by calling `visit_unit` on the visitor directly.
///
/// `self` is consumed, but nothing it holds is read in this method.
// NOTE(review): the map-based `deserialize_unit` further down in this diff
// runs `IgnoredAny.visit_map` over its input before `visit_unit`; confirm
// that leaving the sequence unread here is intended.
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Covered by tests/test_enum_internally_tagged.rs
// newtype_unit
visitor.visit_unit()
}

/// Deserializes a unit struct exactly like a plain unit.
///
/// The struct name is ignored; the call is delegated to `deserialize_unit`
/// with the same visitor.
fn deserialize_unit_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Covered by tests/test_enum_internally_tagged.rs
// newtype_unit_struct
self.deserialize_unit(visitor)
}

/// Deserializes a newtype struct by handing this deserializer itself to the
/// visitor's `visit_newtype_struct`.
///
/// The struct name is ignored. Whatever the visitor deserializes next is
/// therefore read from this same deserializer.
fn deserialize_newtype_struct<V>(self, _name: &str, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}

Comment on lines +1083 to +1111
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure: should we change the behavior of SeqAccessDeserializer and MapAccessDeserializer, or introduce new private deserializers? On the one hand, those deserializers were created to support various serde attributes. On the other hand, technically this is a breaking change because those types are public.

forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
bytes byte_buf option seq tuple
tuple_struct map struct enum identifier ignored_any
}
}
Expand Down Expand Up @@ -1406,6 +1439,8 @@ where
}
}

/// Number of elements still expected in a map. Does not include already read
/// elements.
struct ExpectedInMap(usize);

impl Expected for ExpectedInMap {
Expand Down Expand Up @@ -1479,6 +1514,42 @@ where
visitor.visit_map(self.map)
}

/// Deserializes a unit from the wrapped map access.
///
/// `self.map` is first passed to `IgnoredAny.visit_map`, and only then is
/// `visit_unit` called on the caller's visitor. An error from the
/// `IgnoredAny` pass is propagated through `tri!` before the visitor is
/// reached.
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Covered by tests/test_enum_internally_tagged.rs
// newtype_unit
// Let `IgnoredAny` process the map; its result value is discarded.
// NOTE(review): presumably this skips every remaining entry, as serde's
// `IgnoredAny` is documented to do -- confirm.
tri!(IgnoredAny.visit_map(self.map));
visitor.visit_unit()
}

/// Deserializes a unit struct exactly like a plain unit.
///
/// The struct name is ignored; the call is delegated to `deserialize_unit`
/// with the same visitor.
fn deserialize_unit_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Covered by tests/test_enum_internally_tagged.rs
// newtype_unit_struct
self.deserialize_unit(visitor)
}

/// Deserializes a newtype struct by handing this deserializer itself to the
/// visitor's `visit_newtype_struct`.
///
/// The struct name is ignored. Whatever the visitor deserializes next is
/// therefore read from this same deserializer.
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Covered by tests/test_enum_internally_tagged.rs
// newtype_newtype
visitor.visit_newtype_struct(self)
}

fn deserialize_enum<V>(
self,
_name: &str,
Expand All @@ -1493,7 +1564,7 @@ where

forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
bytes byte_buf option seq tuple
tuple_struct map struct identifier ignored_any
}
}
Expand Down
83 changes: 55 additions & 28 deletions serde/src/private/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ mod content {
use crate::lib::*;

use crate::actually_private;
use crate::de::value::{MapDeserializer, SeqDeserializer};
use crate::de::value::{
ExpectedInSeq, MapAccessDeserializer, MapDeserializer, SeqDeserializer,
};
use crate::de::{
self, size_hint, Deserialize, DeserializeSeed, Deserializer, EnumAccess, Expected,
IgnoredAny, MapAccess, SeqAccess, Unexpected, Visitor,
Expand Down Expand Up @@ -536,9 +538,7 @@ mod content {
}

/// This is the type of the map keys in an internally tagged enum.
///
/// Not public API.
pub enum TagOrContent<'de> {
enum TagOrContent<'de> {
Tag,
Content(Content<'de>),
}
Expand Down Expand Up @@ -855,9 +855,9 @@ mod content {

impl<'de, T> Visitor<'de> for TaggedContentVisitor<T>
where
T: Deserialize<'de>,
T: Deserialize<'de> + DeserializeSeed<'de>,
{
type Value = (T, Content<'de>);
type Value = T::Value;

fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str(self.expecting)
Expand All @@ -867,42 +867,63 @@ mod content {
where
S: SeqAccess<'de>,
{
let tag = match tri!(seq.next_element()) {
let tag: T = match tri!(seq.next_element()) {
Some(tag) => tag,
None => {
return Err(de::Error::missing_field(self.tag_name));
}
};
let rest = de::value::SeqAccessDeserializer::new(seq);
Ok((tag, tri!(Content::deserialize(rest))))
tag.deserialize(de::value::SeqAccessDeserializer::new(seq))
}

fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut tag = None;
let mut vec = Vec::<(Content, Content)>::with_capacity(size_hint::cautious::<(
Content,
Content,
)>(map.size_hint()));
while let Some(k) = tri!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name))) {
match k {
TagOrContent::Tag => {
if tag.is_some() {
return Err(de::Error::duplicate_field(self.tag_name));
// Read the first field. If it is a tag, immediately deserialize the typed data.
// Otherwise, we collect everything until we find the tag, and then deserialize
// using ContentDeserializer.
match tri!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name))) {
Some(TagOrContent::Tag) => {
let tag: T = tri!(map.next_value());
tag.deserialize(MapAccessDeserializer::new(map))
}
Some(TagOrContent::Content(key)) => {
let mut tag = None::<T>;
let mut vec = Vec::<(Content, Content)>::with_capacity(size_hint::cautious::<(
Content,
Content,
)>(
map.size_hint()
));

let v = tri!(map.next_value());
vec.push((key, v));

while let Some(k) =
tri!(map.next_key_seed(TagOrContentVisitor::new(self.tag_name)))
{
match k {
TagOrContent::Tag => {
if tag.is_some() {
return Err(de::Error::duplicate_field(self.tag_name));
}
tag = Some(tri!(map.next_value()));
}
TagOrContent::Content(k) => {
let v = tri!(map.next_value());
vec.push((k, v));
}
}
tag = Some(tri!(map.next_value()));
}
TagOrContent::Content(k) => {
let v = tri!(map.next_value());
vec.push((k, v));
match tag {
None => Err(de::Error::missing_field(self.tag_name)),
Some(tag) => {
tag.deserialize(ContentDeserializer::<M::Error>::new(Content::Map(vec)))
}
}
}
}
match tag {
None => Err(de::Error::missing_field(self.tag_name)),
Some(tag) => Ok((tag, Content::Map(vec))),
}
}
}
Expand Down Expand Up @@ -2296,11 +2317,17 @@ mod content {
)
}

fn visit_seq<S>(self, _: S) -> Result<(), S::Error>
fn visit_seq<S>(self, mut seq: S) -> Result<(), S::Error>
where
S: SeqAccess<'de>,
{
Ok(())
match tri!(seq.next_element()) {
Copy link

@RReverser RReverser Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behaves quite differently from IgnoredAny.visit_map. I think the behaviour should be consistent, as in, iterate over the entire sequence and ignore its values instead of erroring out on non-empty sequence.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that initially, but it made other tests fail and, in general, it is not what you want. A unit / unit struct is represented in a sequence as nothing, so we need to ensure that the sequence is empty. This is consistent with the normal behavior, where struct deserialization from a sequence expects an exact number of values, and with the fact that a flattened unit / unit struct is considered equal to a struct without fields.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, unit_variant_with_unknown_fields is the test that fails if we consume the whole sequence here.

// Unknown elements are not allowed in sequences
assert_de_tokens_error::<InternallyTagged>(
&[
Token::Seq { len: None },
Token::Str("Unit"), // tag
Token::I32(0),
Token::SeqEnd,
],
"invalid length 1, expected 0 elements in sequence",
);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unit / unit struct represented in sequence as nothing

Hm but "nothing" should be pretty different conceptually from "ignored any". I'd expect a custom check just for the nothing case, whereas ignored any should be able to consume anything thrown at it silently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code tries to read something, doesn't matter what. We expect an empty sequence, so if it contains some element, we fail.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, I'm sleepy - I thought you're changing how IgnoredAny works everywhere. I've expanded the context of the diff and I see this is a change on this one specific visitor.

Please disregard my original comment 🤦‍♂️

Although I now wonder if visit_map should be changed to check length as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default, maps in serde allow unknown keys, and when a unit is flattened, all keys become unknown. But you're right -- in the case of #[serde(deny_unknown_fields)] we should return an error if the map is not empty. That's an idea for another PR!

Some(IgnoredAny) => Err(de::Error::invalid_length(
1 + seq.size_hint().unwrap_or(0),
&ExpectedInSeq(0),
)),
None => Ok(()),
}
}

fn visit_map<M>(self, mut access: M) -> Result<(), M::Error>
Expand Down
52 changes: 44 additions & 8 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ fn deserialize_struct(
_serde::de::VariantAccess::struct_variant(__variant, FIELDS, #visitor_expr)
},
StructForm::InternallyTagged(_, deserializer) => quote! {
_serde::Deserializer::deserialize_any(#deserializer, #visitor_expr)
_serde::Deserializer::deserialize_map(#deserializer, #visitor_expr)
},
StructForm::Untagged(_, deserializer) => quote! {
_serde::Deserializer::deserialize_any(#deserializer, #visitor_expr)
Expand Down Expand Up @@ -1397,19 +1397,55 @@ fn deserialize_internally_tagged_enum(
let expecting = format!("internally tagged enum {}", params.type_name());
let expecting = cattrs.expecting().unwrap_or(&expecting);

let this_type = &params.this_type;
let (de_impl_generics, de_ty_generics, ty_generics, where_clause) =
split_with_de_lifetime(params);
let delife = params.borrowed.de_lifetime();

quote_block! {
#variant_visitor

#variants_stmt

let (__tag, __content) = _serde::Deserializer::deserialize_any(
__deserializer,
_serde::__private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?;
let __deserializer = _serde::__private::de::ContentDeserializer::<__D::Error>::new(__content);
struct __Seed #de_impl_generics #where_clause {
tag: __Field,
marker: _serde::__private::PhantomData<#this_type #ty_generics>,
lifetime: _serde::__private::PhantomData<&#delife ()>,
}

match __tag {
#(#variant_arms)*
impl #de_impl_generics _serde::de::Deserialize<#delife> for __Seed #de_ty_generics #where_clause {
fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result<Self, __D::Error>
where
__D: _serde::de::Deserializer<#delife>,
{
_serde::__private::Result::map(
__Field::deserialize(__deserializer),
|__tag| __Seed {
tag: __tag,
marker: _serde::__private::PhantomData,
lifetime: _serde::__private::PhantomData,
}
)
}
}

impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Seed #de_ty_generics #where_clause {
type Value = #this_type #ty_generics;

fn deserialize<__D>(self, __deserializer: __D) -> _serde::__private::Result<Self::Value, __D::Error>
where
__D: _serde::de::Deserializer<#delife>,
{
match self.tag {
#(#variant_arms)*
}
}
}

_serde::Deserializer::deserialize_map(
__deserializer,
_serde::__private::de::TaggedContentVisitor::<__Seed>::new(#tag, #expecting)
)
}
}

Expand Down Expand Up @@ -1862,7 +1898,7 @@ fn deserialize_internally_tagged_variant(
quote!((#default))
});
quote_block! {
_serde::Deserializer::deserialize_any(#deserializer, _serde::__private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))?;
_serde::Deserializer::deserialize_map(#deserializer, _serde::__private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))?;
_serde::__private::Ok(#this_value::#variant_ident #default)
}
}
Expand Down
5 changes: 3 additions & 2 deletions test_suite/tests/test_enum_internally_tagged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,9 @@ fn newtype_map() {
Token::Seq { len: Some(2) },
Token::Str("NewtypeMap"), // tag
Token::Map { len: Some(0) },
Token::MapEnd,
Token::SeqEnd,
// Tokens that could follow, but assert_de_tokens_error does not want them
// Token::MapEnd,
// Token::SeqEnd,
],
"invalid type: sequence, expected a map",
);
Expand Down