Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: optimize BorshSerialize derive for enums with unit variants #262

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 107 additions & 37 deletions borsh-derive/src/internals/serialize/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
let mut fields_body = TokenStream2::new();
let use_discriminant = item::contains_use_discriminant(input)?;
let discriminants = Discriminants::new(&input.variants);
let mut has_unit_variant = false;

for (variant_idx, variant) in input.variants.iter().enumerate() {
let variant_ident = &variant.ident;
Expand All @@ -30,13 +31,16 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
&mut generics_output,
)?;
all_variants_idx_body.extend(variant_output.variant_idx_body);
let (variant_header, variant_body) = (variant_output.header, variant_output.body);
fields_body.extend(quote!(
#enum_ident::#variant_ident #variant_header => {
#variant_body
}
))
match variant_output.body {
VariantBody::Unit => has_unit_variant = true,
VariantBody::Fields(VariantFields { header, body }) => fields_body.extend(quote!(
#enum_ident::#variant_ident #header => {
#body
}
)),
}
}
let fields_body = optimize_fields_body(fields_body, has_unit_variant);
generics_output.extend(&mut where_clause, &cratename);

Ok(quote! {
Expand All @@ -47,31 +51,78 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
};
writer.write_all(&variant_idx.to_le_bytes())?;

match self {
#fields_body
}
#fields_body
Ok(())
}
}
})
}

struct VariantOutput {
fn optimize_fields_body(fields_body: TokenStream2, has_unit_variant: bool) -> TokenStream2 {
if fields_body.is_empty() {
// If we no variants with fields, there's nothing to match against. Just
// re-use the empty token stream.
fields_body
} else {
let unit_fields_catchall = if has_unit_variant {
// We had some variants with unit fields, create a catch-all for
// these to be used at the bottom.
quote!(
_ => {}
)
} else {
TokenStream2::new()
};
// Create a match that serialises all the fields for each non-unit
// variant and add a catch-all at the bottom if we do have unit
// variants.
quote!(
match self {
#fields_body
#unit_fields_catchall
}
)
}
}

#[derive(Default)]
struct VariantFields {
header: TokenStream2,
body: TokenStream2,
variant_idx_body: TokenStream2,
}

impl VariantOutput {
fn new() -> Self {
Self {
body: TokenStream2::new(),
header: TokenStream2::new(),
variant_idx_body: TokenStream2::new(),
impl VariantFields {
fn named_header(self) -> Self {
let header = self.header;

VariantFields {
// `..` pattern matching works even if all fields were specified
header: quote! { { #header.. }},
body: self.body,
}
}
fn unnamed_header(self) -> Self {
let header = self.header;

VariantFields {
header: quote! { ( #header )},
body: self.body,
}
}
}

enum VariantBody {
// No body variant, unit enum variant.
Unit,
// Variant with body (fields)
Fields(VariantFields),
}

struct VariantOutput {
body: VariantBody,
variant_idx_body: TokenStream2,
}

fn process_variant(
variant: &Variant,
enum_ident: &Ident,
Expand All @@ -80,36 +131,39 @@ fn process_variant(
generics: &mut serialize::GenericsOutput,
) -> syn::Result<VariantOutput> {
let variant_ident = &variant.ident;
let mut variant_output = VariantOutput::new();
match &variant.fields {
let variant_output = match &variant.fields {
Fields::Named(fields) => {
let mut variant_fields = VariantFields::default();
for field in &fields.named {
let field_id = serialize::FieldId::Enum(field.ident.clone().unwrap());
process_field(field, field_id, cratename, generics, &mut variant_output)?;
process_field(field, field_id, cratename, generics, &mut variant_fields)?;
}
VariantOutput {
body: VariantBody::Fields(variant_fields.named_header()),
variant_idx_body: quote!(
#enum_ident::#variant_ident {..} => #discriminant_value,
),
}
let header = variant_output.header;
// `..` pattern matching works even if all fields were specified
variant_output.header = quote! { { #header.. }};
variant_output.variant_idx_body = quote!(
#enum_ident::#variant_ident {..} => #discriminant_value,
);
}
Fields::Unnamed(fields) => {
let mut variant_fields = VariantFields::default();
for (field_idx, field) in fields.unnamed.iter().enumerate() {
let field_id = serialize::FieldId::new_enum_unnamed(field_idx)?;
process_field(field, field_id, cratename, generics, &mut variant_output)?;
process_field(field, field_id, cratename, generics, &mut variant_fields)?;
}
VariantOutput {
body: VariantBody::Fields(variant_fields.unnamed_header()),
variant_idx_body: quote!(
#enum_ident::#variant_ident(..) => #discriminant_value,
),
}
let header = variant_output.header;
variant_output.header = quote! { ( #header )};
variant_output.variant_idx_body = quote!(
#enum_ident::#variant_ident(..) => #discriminant_value,
);
}
Fields::Unit => {
variant_output.variant_idx_body = quote!(
Fields::Unit => VariantOutput {
body: VariantBody::Unit,
variant_idx_body: quote!(
#enum_ident::#variant_ident => #discriminant_value,
);
}
),
},
};
Ok(variant_output)
}
Expand All @@ -119,7 +173,7 @@ fn process_field(
field_id: serialize::FieldId,
cratename: &Path,
generics: &mut serialize::GenericsOutput,
output: &mut VariantOutput,
output: &mut VariantFields,
) -> syn::Result<()> {
let parsed = field::Attributes::parse(&field.attrs)?;

Expand Down Expand Up @@ -425,4 +479,20 @@ mod tests {

local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn mixed_with_unit_variants() {
let item_enum: ItemEnum = syn::parse2(quote! {
enum X {
A(u16),
B,
C {x: i32, y: i32},
D,
}
})
.unwrap();
let actual = process(&item_enum, default_cratename()).unwrap();

local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@ impl borsh::ser::BorshSerialize for X {
X::F => 5u8,
};
writer.write_all(&variant_idx.to_le_bytes())?;
match self {
X::A => {}
X::B => {}
X::C => {}
X::D => {}
X::E => {}
X::F => {}
}
Ok(())
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@ impl borsh::ser::BorshSerialize for X {
X::F => 10 + 1,
};
writer.write_all(&variant_idx.to_le_bytes())?;
match self {
X::A => {}
X::B => {}
X::C => {}
X::D => {}
X::E => {}
X::F => {}
}
Ok(())
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
source: borsh-derive/src/internals/serialize/enums/mod.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::ser::BorshSerialize for X {
fn serialize<W: borsh::io::Write>(
&self,
writer: &mut W,
) -> ::core::result::Result<(), borsh::io::Error> {
let variant_idx: u8 = match self {
X::A(..) => 0u8,
X::B => 1u8,
X::C { .. } => 2u8,
X::D => 3u8,
};
writer.write_all(&variant_idx.to_le_bytes())?;
match self {
X::A(id0) => {
borsh::BorshSerialize::serialize(id0, writer)?;
}
X::C { x, y, .. } => {
borsh::BorshSerialize::serialize(x, writer)?;
borsh::BorshSerialize::serialize(y, writer)?;
}
_ => {}
}
Ok(())
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: borsh/tests/test_simple_structs.rs
expression: encoded
---
[
1,
]
15 changes: 15 additions & 0 deletions borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: borsh/tests/test_simple_structs.rs
expression: encoded
---
[
2,
132,
0,
0,
0,
239,
255,
255,
255,
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: borsh/tests/test_simple_structs.rs
expression: encoded
---
[
3,
]
9 changes: 9 additions & 0 deletions borsh/tests/snapshots/test_simple_structs__mixed_enum.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: borsh/tests/test_simple_structs.rs
expression: encoded
---
[
0,
13,
0,
]
27 changes: 27 additions & 0 deletions borsh/tests/test_simple_structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,30 @@ fn test_object_length() {

assert_eq!(encoded_a_len, len_helper_result);
}

#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
enum MixedWithUnitVariants {
A(u16),
B,
C { x: i32, y: i32 },
D,
}

#[test]
fn test_mixed_enum() {
let vars = vec![
MixedWithUnitVariants::A(13),
MixedWithUnitVariants::B,
MixedWithUnitVariants::C { x: 132, y: -17 },
MixedWithUnitVariants::D,
];
for variant in vars {
let encoded = to_vec(&variant).unwrap();
#[cfg(feature = "std")]
insta::assert_debug_snapshot!(encoded);

let decoded = from_slice::<MixedWithUnitVariants>(&encoded).unwrap();

assert_eq!(variant, decoded);
}
}
Loading