Skip to content

Commit

Permalink
Allowed Enum variants to be individually marked as untagged (#2403)
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 authored Jun 8, 2023
1 parent bbba632 commit 48e5753
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 11 deletions.
45 changes: 41 additions & 4 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,32 @@ fn deserialize_enum(
params: &Parameters,
variants: &[Variant],
cattrs: &attr::Container,
) -> Fragment {
// The variants have already been checked (in ast.rs) that all untagged variants appear at the end
match variants
.iter()
.enumerate()
.find(|(_, var)| var.attrs.untagged())
{
Some((variant_idx, _)) => {
let (tagged, untagged) = variants.split_at(variant_idx);
let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs));
let tagged_frag = |deserializer| {
Some(Expr(quote_block! {
let __deserializer = #deserializer;
#tagged_frag
}))
};
deserialize_untagged_enum_after(params, untagged, cattrs, tagged_frag)
}
None => deserialize_homogeneous_enum(params, variants, cattrs),
}
}

fn deserialize_homogeneous_enum(
params: &Parameters,
variants: &[Variant],
cattrs: &attr::Container,
) -> Fragment {
match cattrs.tag() {
attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs),
Expand Down Expand Up @@ -1667,6 +1693,17 @@ fn deserialize_untagged_enum(
variants: &[Variant],
cattrs: &attr::Container,
) -> Fragment {
deserialize_untagged_enum_after(params, variants, cattrs, |_| None)
}

fn deserialize_untagged_enum_after(
params: &Parameters,
variants: &[Variant],
cattrs: &attr::Container,
first_attempt: impl FnOnce(TokenStream) -> Option<Expr>,
) -> Fragment {
let deserializer =
quote!(_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content));
let attempts = variants
.iter()
.filter(|variant| !variant.attrs.skip_deserializing())
Expand All @@ -1675,12 +1712,12 @@ fn deserialize_untagged_enum(
params,
variant,
cattrs,
quote!(
_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content)
),
deserializer.clone(),
))
});

let attempts = first_attempt(deserializer.clone())
.into_iter()
.chain(attempts);
// TODO this message could be better by saving the errors from the failed
// attempts. The heuristic used by TOML was to count the number of fields
// processed before an error, and use the error that happened after the
Expand Down
9 changes: 7 additions & 2 deletions serde_derive/src/internals/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ fn enum_from_ast<'a>(
variants: &'a Punctuated<syn::Variant, Token![,]>,
container_default: &attr::Default,
) -> Vec<Variant<'a>> {
let mut seen_untagged = false;
variants
.iter()
.map(|variant| {
Expand All @@ -153,8 +154,12 @@ fn enum_from_ast<'a>(
fields,
original: variant,
}
})
.collect()
}).inspect(|variant| {
if !variant.attrs.untagged() && seen_untagged {
cx.error_spanned_by(&variant.ident, "all variants with the #[serde(untagged)] attribute must be placed at the end of the enum")
}
seen_untagged = variant.attrs.untagged()
}).collect()
}

fn struct_from_ast<'a>(
Expand Down
9 changes: 9 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ pub struct Variant {
serialize_with: Option<syn::ExprPath>,
deserialize_with: Option<syn::ExprPath>,
borrow: Option<BorrowAttribute>,
untagged: bool,
}

struct BorrowAttribute {
Expand All @@ -762,6 +763,7 @@ impl Variant {
let mut serialize_with = Attr::none(cx, SERIALIZE_WITH);
let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH);
let mut borrow = Attr::none(cx, BORROW);
let mut untagged = BoolAttr::none(cx, UNTAGGED);

for attr in &variant.attrs {
if attr.path() != SERDE {
Expand Down Expand Up @@ -879,6 +881,8 @@ impl Variant {
cx.error_spanned_by(variant, msg);
}
}
} else if meta.path == UNTAGGED {
untagged.set_true(&meta.path);
} else {
let path = meta.path.to_token_stream().to_string().replace(' ', "");
return Err(
Expand All @@ -905,6 +909,7 @@ impl Variant {
serialize_with: serialize_with.get(),
deserialize_with: deserialize_with.get(),
borrow: borrow.get(),
untagged: untagged.get(),
}
}

Expand Down Expand Up @@ -956,6 +961,10 @@ impl Variant {
pub fn deserialize_with(&self) -> Option<&syn::ExprPath> {
self.deserialize_with.as_ref()
}

pub fn untagged(&self) -> bool {
self.untagged
}
}

/// Represents field attribute information
Expand Down
10 changes: 5 additions & 5 deletions serde_derive/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,17 +473,17 @@ fn serialize_variant(
}
};

let body = Match(match cattrs.tag() {
attr::TagType::External => {
let body = Match(match (cattrs.tag(), variant.attrs.untagged()) {
(attr::TagType::External, false) => {
serialize_externally_tagged_variant(params, variant, variant_index, cattrs)
}
attr::TagType::Internal { tag } => {
(attr::TagType::Internal { tag }, false) => {
serialize_internally_tagged_variant(params, variant, cattrs, tag)
}
attr::TagType::Adjacent { tag, content } => {
(attr::TagType::Adjacent { tag, content }, false) => {
serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content)
}
attr::TagType::None => serialize_untagged_variant(params, variant, cattrs),
(attr::TagType::None, _) | (_, true) => serialize_untagged_variant(params, variant, cattrs),
});

quote! {
Expand Down
153 changes: 153 additions & 0 deletions test_suite/tests/test_annotations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2442,6 +2442,159 @@ fn test_untagged_enum_containing_flatten() {
);
}

#[test]
fn test_partially_untagged_enum() {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
enum Exp {
Lambda(u32, Box<Exp>),
#[serde(untagged)]
App(Box<Exp>, Box<Exp>),
#[serde(untagged)]
Var(u32),
}
use Exp::*;

let data = Lambda(0, Box::new(App(Box::new(Var(0)), Box::new(Var(0)))));
assert_tokens(
&data,
&[
Token::TupleVariant {
name: "Exp",
variant: "Lambda",
len: 2,
},
Token::U32(0),
Token::Tuple { len: 2 },
Token::U32(0),
Token::U32(0),
Token::TupleEnd,
Token::TupleVariantEnd,
],
);
}

#[test]
fn test_partially_untagged_enum_generic() {
trait Trait<T> {
type Assoc;
type Assoc2;
}

#[derive(Serialize, Deserialize, PartialEq, Debug)]
enum E<A, B, C> where A: Trait<C, Assoc2=B> {
A(A::Assoc),
#[serde(untagged)]
B(A::Assoc2),
}

impl<T> Trait<T> for () {
type Assoc = T;
type Assoc2 = bool;
}

type MyE = E<(), bool, u32>;
use E::*;

assert_tokens::<MyE>(&B(true), &[Token::Bool(true)]);

assert_tokens::<MyE>(
&A(5),
&[
Token::NewtypeVariant {
name: "E",
variant: "A",
},
Token::U32(5),
],
);
}

#[test]
fn test_partially_untagged_enum_desugared() {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
enum Test {
A(u32, u32),
B(u32),
#[serde(untagged)]
C(u32),
#[serde(untagged)]
D(u32, u32),
}
use Test::*;

mod desugared {
use super::*;
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub(super) enum Test {
A(u32, u32),
B(u32),
}
}
use desugared::Test as TestTagged;

#[derive(Serialize, Deserialize, PartialEq, Debug)]
#[serde(untagged)]
enum TestUntagged {
Tagged(TestTagged),
C(u32),
D(u32, u32),
}

impl From<Test> for TestUntagged {
fn from(test: Test) -> Self {
match test {
A(x, y) => TestUntagged::Tagged(TestTagged::A(x, y)),
B(x) => TestUntagged::Tagged(TestTagged::B(x)),
C(x) => TestUntagged::C(x),
D(x, y) => TestUntagged::D(x, y),
}
}
}

fn assert_tokens_desugared(value: Test, tokens: &[Token]) {
assert_tokens(&value, tokens);
let desugared: TestUntagged = value.into();
assert_tokens(&desugared, tokens);
}

assert_tokens_desugared(
A(0, 1),
&[
Token::TupleVariant {
name: "Test",
variant: "A",
len: 2,
},
Token::U32(0),
Token::U32(1),
Token::TupleVariantEnd,
],
);

assert_tokens_desugared(
B(1),
&[
Token::NewtypeVariant {
name: "Test",
variant: "B",
},
Token::U32(1),
],
);

assert_tokens_desugared(C(2), &[Token::U32(2)]);

assert_tokens_desugared(
D(3, 5),
&[
Token::Tuple { len: 2 },
Token::U32(3),
Token::U32(5),
Token::TupleEnd,
],
);
}

#[test]
fn test_flatten_untagged_enum() {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use serde_derive::Serialize;

#[derive(Serialize)]
enum E {
#[serde(untagged)]
A(u8),
B(String),
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: all variants with the #[serde(untagged)] attribute must be placed at the end of the enum
--> tests/ui/enum-representation/partially_tagged_wrong_order.rs:7:5
|
7 | B(String),
| ^

0 comments on commit 48e5753

Please sign in to comment.