From 48e5753e760641bceebeea4ee9d7e65f0abb0eca Mon Sep 17 00:00:00 2001 From: David Ewert <33990711+dewert99@users.noreply.github.com> Date: Wed, 7 Jun 2023 20:58:59 -0700 Subject: [PATCH] Allowed Enum variants to be individually marked as untagged (#2403) --- serde_derive/src/de.rs | 45 +++++- serde_derive/src/internals/ast.rs | 9 +- serde_derive/src/internals/attr.rs | 9 ++ serde_derive/src/ser.rs | 10 +- test_suite/tests/test_annotations.rs | 153 ++++++++++++++++++ .../partially_tagged_wrong_order.rs | 10 ++ .../partially_tagged_wrong_order.stderr | 5 + 7 files changed, 230 insertions(+), 11 deletions(-) create mode 100644 test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.rs create mode 100644 test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.stderr diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 33a976d7b..0a274e567 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1166,6 +1166,32 @@ fn deserialize_enum( params: &Parameters, variants: &[Variant], cattrs: &attr::Container, +) -> Fragment { + // The variants have already been checked (in ast.rs) that all untagged variants appear at the end + match variants + .iter() + .enumerate() + .find(|(_, var)| var.attrs.untagged()) + { + Some((variant_idx, _)) => { + let (tagged, untagged) = variants.split_at(variant_idx); + let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs)); + let tagged_frag = |deserializer| { + Some(Expr(quote_block! { + let __deserializer = #deserializer; + #tagged_frag + })) + }; + deserialize_untagged_enum_after(params, untagged, cattrs, tagged_frag) + } + None => deserialize_homogeneous_enum(params, variants, cattrs), + } +} + +fn deserialize_homogeneous_enum( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, ) -> Fragment { match cattrs.tag() { attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs), @@ -1667,6 +1693,17 @@ fn deserialize_untagged_enum( variants: &[Variant], cattrs: &attr::Container, ) -> Fragment { + deserialize_untagged_enum_after(params, variants, cattrs, |_| None) +} + +fn deserialize_untagged_enum_after( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, + first_attempt: impl FnOnce(TokenStream) -> Option, +) -> Fragment { + let deserializer = + quote!(_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content)); let attempts = variants .iter() .filter(|variant| !variant.attrs.skip_deserializing()) @@ -1675,12 +1712,12 @@ fn deserialize_untagged_enum( params, variant, cattrs, - quote!( - _serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content) - ), + deserializer.clone(), )) }); - + let attempts = first_attempt(deserializer.clone()) + .into_iter() + .chain(attempts); // TODO this message could be better by saving the errors from the failed // attempts. The heuristic used by TOML was to count the number of fields // processed before an error, and use the error that happened after the diff --git a/serde_derive/src/internals/ast.rs b/serde_derive/src/internals/ast.rs index 2a6950b2a..ca3dd33ad 100644 --- a/serde_derive/src/internals/ast.rs +++ b/serde_derive/src/internals/ast.rs @@ -140,6 +140,7 @@ fn enum_from_ast<'a>( variants: &'a Punctuated, container_default: &attr::Default, ) -> Vec> { + let mut seen_untagged = false; variants .iter() .map(|variant| { @@ -153,8 +154,12 @@ fn enum_from_ast<'a>( fields, original: variant, } - }) - .collect() + }).inspect(|variant| { + if !variant.attrs.untagged() && seen_untagged { + cx.error_spanned_by(&variant.ident, "all variants with the #[serde(untagged)] attribute must be placed at the end of the enum") + } + seen_untagged = variant.attrs.untagged() + }).collect() } fn struct_from_ast<'a>( diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index b0a7d08a2..bff82191b 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -740,6 +740,7 @@ pub struct Variant { serialize_with: Option, deserialize_with: Option, borrow: Option, + untagged: bool, } struct BorrowAttribute { @@ -762,6 +763,7 @@ impl Variant { let mut serialize_with = Attr::none(cx, SERIALIZE_WITH); let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH); let mut borrow = Attr::none(cx, BORROW); + let mut untagged = BoolAttr::none(cx, UNTAGGED); for attr in &variant.attrs { if attr.path() != SERDE { @@ -879,6 +881,8 @@ impl Variant { cx.error_spanned_by(variant, msg); } } + } else if meta.path == UNTAGGED { + untagged.set_true(&meta.path); } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( @@ -905,6 +909,7 @@ impl Variant { serialize_with: serialize_with.get(), deserialize_with: deserialize_with.get(), borrow: borrow.get(), + untagged: untagged.get(), } } @@ -956,6 +961,10 @@ impl Variant { pub fn deserialize_with(&self) -> Option<&syn::ExprPath> { self.deserialize_with.as_ref() } + + pub fn untagged(&self) -> bool { + self.untagged + } } /// Represents field attribute information diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index dcc1771ec..28dbbbc72 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -473,17 +473,17 @@ fn serialize_variant( } }; - let body = Match(match cattrs.tag() { - attr::TagType::External => { + let body = Match(match (cattrs.tag(), variant.attrs.untagged()) { + (attr::TagType::External, false) => { serialize_externally_tagged_variant(params, variant, variant_index, cattrs) } - attr::TagType::Internal { tag } => { + (attr::TagType::Internal { tag }, false) => { serialize_internally_tagged_variant(params, variant, cattrs, tag) } - attr::TagType::Adjacent { tag, content } => { + (attr::TagType::Adjacent { tag, content }, false) => { serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content) } - attr::TagType::None => serialize_untagged_variant(params, variant, cattrs), + (attr::TagType::None, _) | (_, true) => serialize_untagged_variant(params, variant, cattrs), }); quote! { diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 117cd3f4e..9dd894563 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -2442,6 +2442,159 @@ fn test_untagged_enum_containing_flatten() { ); } +#[test] +fn test_partially_untagged_enum() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + enum Exp { + Lambda(u32, Box), + #[serde(untagged)] + App(Box, Box), + #[serde(untagged)] + Var(u32), + } + use Exp::*; + + let data = Lambda(0, Box::new(App(Box::new(Var(0)), Box::new(Var(0))))); + assert_tokens( + &data, + &[ + Token::TupleVariant { + name: "Exp", + variant: "Lambda", + len: 2, + }, + Token::U32(0), + Token::Tuple { len: 2 }, + Token::U32(0), + Token::U32(0), + Token::TupleEnd, + Token::TupleVariantEnd, + ], + ); +} + +#[test] +fn test_partially_untagged_enum_generic() { + trait Trait { + type Assoc; + type Assoc2; + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + enum E where A: Trait { + A(A::Assoc), + #[serde(untagged)] + B(A::Assoc2), + } + + impl Trait for () { + type Assoc = T; + type Assoc2 = bool; + } + + type MyE = E<(), bool, u32>; + use E::*; + + assert_tokens::(&B(true), &[Token::Bool(true)]); + + assert_tokens::( + &A(5), + &[ + Token::NewtypeVariant { + name: "E", + variant: "A", + }, + Token::U32(5), + ], + ); +} + +#[test] +fn test_partially_untagged_enum_desugared() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + enum Test { + A(u32, u32), + B(u32), + #[serde(untagged)] + C(u32), + #[serde(untagged)] + D(u32, u32), + } + use Test::*; + + mod desugared { + use super::*; + #[derive(Serialize, Deserialize, PartialEq, Debug)] + pub(super) enum Test { + A(u32, u32), + B(u32), + } + } + use desugared::Test as TestTagged; + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[serde(untagged)] + enum TestUntagged { + Tagged(TestTagged), + C(u32), + D(u32, u32), + } + + impl From for TestUntagged { + fn from(test: Test) -> Self { + match test { + A(x, y) => TestUntagged::Tagged(TestTagged::A(x, y)), + B(x) => TestUntagged::Tagged(TestTagged::B(x)), + C(x) => TestUntagged::C(x), + D(x, y) => TestUntagged::D(x, y), + } + } + } + + fn assert_tokens_desugared(value: Test, tokens: &[Token]) { + assert_tokens(&value, tokens); + let desugared: TestUntagged = value.into(); + assert_tokens(&desugared, tokens); + } + + assert_tokens_desugared( + A(0, 1), + &[ + Token::TupleVariant { + name: "Test", + variant: "A", + len: 2, + }, + Token::U32(0), + Token::U32(1), + Token::TupleVariantEnd, + ], + ); + + assert_tokens_desugared( + B(1), + &[ + Token::NewtypeVariant { + name: "Test", + variant: "B", + }, + Token::U32(1), + ], + ); + + assert_tokens_desugared(C(2), &[Token::U32(2)]); + + assert_tokens_desugared( + D(3, 5), + &[ + Token::Tuple { len: 2 }, + Token::U32(3), + Token::U32(5), + Token::TupleEnd, + ], + ); +} + #[test] fn test_flatten_untagged_enum() { #[derive(Serialize, Deserialize, PartialEq, Debug)] diff --git a/test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.rs b/test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.rs new file mode 100644 index 000000000..a33398cae --- /dev/null +++ b/test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.rs @@ -0,0 +1,10 @@ +use serde_derive::Serialize; + +#[derive(Serialize)] +enum E { + #[serde(untagged)] + A(u8), + B(String), +} + +fn main() {} diff --git a/test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.stderr b/test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.stderr new file mode 100644 index 000000000..17d3c3f93 --- /dev/null +++ b/test_suite/tests/ui/enum-representation/partially_tagged_wrong_order.stderr @@ -0,0 +1,5 @@ +error: all variants with the #[serde(untagged)] attribute must be placed at the end of the enum + --> tests/ui/enum-representation/partially_tagged_wrong_order.rs:7:5 + | +7 | B(String), + | ^