Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

has_flatten rework #2795

Merged
merged 5 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 17 additions & 25 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,11 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment {
} else if let attr::Identifier::No = cont.attrs.identifier() {
match &cont.data {
Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs),
Data::Struct(Style::Struct, fields) => deserialize_struct(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
StructForm::Struct,
),
Data::Struct(Style::Struct, fields) => {
deserialize_struct(params, fields, &cont.attrs, StructForm::Struct)
}
Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => {
deserialize_tuple(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
TupleForm::Tuple,
)
deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple)
}
Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs),
}
Expand Down Expand Up @@ -469,11 +459,10 @@ fn deserialize_tuple(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: TupleForm,
) -> Fragment {
assert!(
!has_flatten,
!has_flatten(fields),
"tuples and tuple variants cannot have flatten fields"
);

Expand Down Expand Up @@ -594,7 +583,7 @@ fn deserialize_tuple_in_place(
cattrs: &attr::Container,
) -> Fragment {
assert!(
!cattrs.has_flatten(),
!has_flatten(fields),
"tuples and tuple variants cannot have flatten fields"
);

Expand Down Expand Up @@ -927,7 +916,6 @@ fn deserialize_struct(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: StructForm,
) -> Fragment {
let this_type = &params.this_type;
Expand Down Expand Up @@ -976,6 +964,8 @@ fn deserialize_struct(
)
})
.collect();

let has_flatten = has_flatten(fields);
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, has_flatten);

// untagged struct variants do not get a visit_seq method. The same applies to
Expand Down Expand Up @@ -1115,7 +1105,7 @@ fn deserialize_struct_in_place(
) -> Option<Fragment> {
// for now we do not support in_place deserialization for structs that
// are represented as map.
if cattrs.has_flatten() {
if has_flatten(fields) {
return None;
}

Expand Down Expand Up @@ -1831,14 +1821,12 @@ fn deserialize_externally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::ExternallyTagged(variant_ident),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::ExternallyTagged(variant_ident),
),
}
Expand Down Expand Up @@ -1882,7 +1870,6 @@ fn deserialize_internally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::InternallyTagged(variant_ident, deserializer),
),
Style::Tuple => unreachable!("checked in serde_derive_internals"),
Expand Down Expand Up @@ -1933,14 +1920,12 @@ fn deserialize_untagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::Untagged(variant_ident, deserializer),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::Untagged(variant_ident, deserializer),
),
}
Expand Down Expand Up @@ -2707,7 +2692,7 @@ fn deserialize_map_in_place(
cattrs: &attr::Container,
) -> Fragment {
assert!(
!cattrs.has_flatten(),
!has_flatten(fields),
"inplace deserialization of maps does not support flatten fields"
);

Expand Down Expand Up @@ -3042,6 +3027,13 @@ fn effective_style(variant: &Variant) -> Style {
}
}

/// True if there are fields that is not skipped and has a `#[serde(flatten)]` attribute.
fn has_flatten(fields: &[Field]) -> bool {
fields
.iter()
.any(|field| field.attrs.flatten() && !field.attrs.skip_deserializing())
}

struct DeImplGenerics<'a>(&'a Parameters);
#[cfg(feature = "deserialize_in_place")]
struct InPlaceImplGenerics<'a>(&'a Parameters);
Expand Down
14 changes: 1 addition & 13 deletions serde_derive/src/internals/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl<'a> Container<'a> {
item: &'a syn::DeriveInput,
derive: Derive,
) -> Option<Container<'a>> {
let mut attrs = attr::Container::from_ast(cx, item);
let attrs = attr::Container::from_ast(cx, item);

let mut data = match &item.data {
syn::Data::Enum(data) => Data::Enum(enum_from_ast(cx, &data.variants, attrs.default())),
Expand All @@ -77,16 +77,11 @@ impl<'a> Container<'a> {
}
};

let mut has_flatten = false;
match &mut data {
Data::Enum(variants) => {
for variant in variants {
variant.attrs.rename_by_rules(attrs.rename_all_rules());
for field in &mut variant.fields {
if field.attrs.flatten() {
has_flatten = true;
variant.attrs.mark_has_flatten();
}
field.attrs.rename_by_rules(
variant
.attrs
Expand All @@ -98,18 +93,11 @@ impl<'a> Container<'a> {
}
Data::Struct(_, fields) => {
for field in fields {
if field.attrs.flatten() {
has_flatten = true;
}
field.attrs.rename_by_rules(attrs.rename_all_rules());
}
}
}

if has_flatten {
attrs.mark_has_flatten();
}

let mut item = Container {
ident: item.ident.clone(),
attrs,
Expand Down
47 changes: 0 additions & 47 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,6 @@ pub struct Container {
type_into: Option<syn::Type>,
remote: Option<syn::Path>,
identifier: Identifier,
/// True if container is a struct and has a field with `#[serde(flatten)]`,
/// or is an enum with a struct variant which has a field with
/// `#[serde(flatten)]`.
///
/// ```ignore
/// struct Container {
/// #[serde(flatten)]
/// some_field: (),
/// }
/// enum Container {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
serde_path: Option<syn::Path>,
is_packed: bool,
/// Error message generated when type can't be deserialized
Expand Down Expand Up @@ -603,7 +586,6 @@ impl Container {
type_into: type_into.get(),
remote: remote.get(),
identifier: decide_identifier(cx, item, field_identifier, variant_identifier),
has_flatten: false,
serde_path: serde_path.get(),
is_packed,
expecting: expecting.get(),
Expand Down Expand Up @@ -671,14 +653,6 @@ impl Container {
self.identifier
}

pub fn has_flatten(&self) -> bool {
self.has_flatten
}

pub fn mark_has_flatten(&mut self) {
self.has_flatten = true;
}

pub fn custom_serde_path(&self) -> Option<&syn::Path> {
self.serde_path.as_ref()
}
Expand Down Expand Up @@ -810,18 +784,6 @@ pub struct Variant {
rename_all_rules: RenameAllRules,
ser_bound: Option<Vec<syn::WherePredicate>>,
de_bound: Option<Vec<syn::WherePredicate>>,
/// True if variant is a struct variant which contains a field with
/// `#[serde(flatten)]`.
///
/// ```ignore
/// enum Enum {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
skip_deserializing: bool,
skip_serializing: bool,
other: bool,
Expand Down Expand Up @@ -991,7 +953,6 @@ impl Variant {
},
ser_bound: ser_bound.get(),
de_bound: de_bound.get(),
has_flatten: false,
skip_deserializing: skip_deserializing.get(),
skip_serializing: skip_serializing.get(),
other: other.get(),
Expand Down Expand Up @@ -1034,14 +995,6 @@ impl Variant {
self.de_bound.as_ref().map(|vec| &vec[..])
}

pub fn has_flatten(&self) -> bool {
self.has_flatten
}

pub fn mark_has_flatten(&mut self) {
self.has_flatten = true;
}

pub fn skip_deserializing(&self) -> bool {
self.skip_deserializing
}
Expand Down
33 changes: 12 additions & 21 deletions serde_derive/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,18 @@ fn serialize_tuple_struct(
}

fn serialize_struct(params: &Parameters, fields: &[Field], cattrs: &attr::Container) -> Fragment {
assert!(fields.len() as u64 <= u64::from(u32::MAX));
assert!(
fields.len() as u64 <= u64::from(u32::MAX),
"too many fields in {}: {}, maximum supported count is {}",
cattrs.name().serialize_name(),
fields.len(),
u32::MAX
);

if cattrs.has_flatten() {
let has_non_skipped_flatten = fields
.iter()
.any(|field| field.attrs.flatten() && !field.attrs.skip_serializing());
if has_non_skipped_flatten {
serialize_struct_as_map(params, fields, cattrs)
} else {
serialize_struct_as_struct(params, fields, cattrs)
Expand Down Expand Up @@ -370,26 +379,8 @@ fn serialize_struct_as_map(

let let_mut = mut_if(serialized_fields.peek().is_some() || tag_field_exists);

let len = if cattrs.has_flatten() {
quote!(_serde::__private::None)
} else {
let len = serialized_fields
.map(|field| match field.attrs.skip_serializing_if() {
None => quote!(1),
Some(path) => {
let field_expr = get_member(params, field, &field.member);
quote!(if #path(#field_expr) { 0 } else { 1 })
}
})
.fold(
quote!(#tag_field_exists as usize),
|sum, expr| quote!(#sum + #expr),
);
quote!(_serde::__private::Some(#len))
};

quote_block! {
let #let_mut __serde_state = _serde::Serializer::serialize_map(__serializer, #len)?;
let #let_mut __serde_state = _serde::Serializer::serialize_map(__serializer, _serde::__private::None)?;
#tag_field
#(#serialize_fields)*
_serde::ser::SerializeMap::end(__serde_state)
Expand Down
32 changes: 24 additions & 8 deletions test_suite/tests/test_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,32 @@ fn test_gen() {
}
assert::<FlattenWith>();

#[derive(Serialize, Deserialize)]
pub struct Flatten<T> {
#[serde(flatten)]
t: T,
}

#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FlattenDenyUnknown<T> {
#[serde(flatten)]
t: T,
}

#[derive(Serialize, Deserialize)]
pub struct SkipDeserializing<T> {
#[serde(skip_deserializing)]
flat: T,
}

#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SkipDeserializingDenyUnknown<T> {
#[serde(skip_deserializing)]
flat: T,
}

#[derive(Serialize, Deserialize)]
pub struct StaticStrStruct<'a> {
a: &'a str,
Expand Down Expand Up @@ -720,14 +739,11 @@ fn test_gen() {
flat: StdOption<T>,
}

#[allow(clippy::collection_is_never_read)] // FIXME
const _: () = {
#[derive(Serialize, Deserialize)]
pub struct FlattenSkipDeserializing<T> {
#[serde(flatten, skip_deserializing)]
flat: T,
}
};
#[derive(Serialize, Deserialize)]
pub struct FlattenSkipDeserializing<T> {
#[serde(flatten, skip_deserializing)]
flat: T,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
Expand Down