Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix max_encoded_len for Compact fields #508

Merged
merged 6 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions derive/src/max_encoded_len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

use crate::{
trait_bounds,
utils::{codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip},
utils::{self, codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip},
};
use quote::{quote, quote_spanned};
use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Fields, Type};
use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Field, Fields};

/// impl for `#[derive(MaxEncodedLen)]`
pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand All @@ -43,13 +43,13 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
parse_quote!(#crate_path::MaxEncodedLen),
None,
has_dumb_trait_bound(&input.attrs),
&crate_path
&crate_path,
) {
return e.to_compile_error().into()
}
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let data_expr = data_length_expr(&input.data);
let data_expr = data_length_expr(&input.data, &crate_path);

quote::quote!(
const _: () = {
Expand All @@ -64,22 +64,22 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
}

/// generate an expression to sum up the max encoded length from several fields
fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
let type_iter: Box<dyn Iterator<Item = &Type>> = match fields {
Fields::Named(ref fields) => Box::new(
fields.named.iter().filter_map(|field| if should_skip(&field.attrs) {
fn fields_length_expr(fields: &Fields, crate_path: &syn::Path) -> proc_macro2::TokenStream {
let fields_iter: Box<dyn Iterator<Item = &Field>> = match fields {
Fields::Named(ref fields) => Box::new(fields.named.iter().filter_map(|field| {
if should_skip(&field.attrs) {
None
} else {
Some(&field.ty)
})
),
Fields::Unnamed(ref fields) => Box::new(
fields.unnamed.iter().filter_map(|field| if should_skip(&field.attrs) {
Some(field)
}
})),
Fields::Unnamed(ref fields) => Box::new(fields.unnamed.iter().filter_map(|field| {
if should_skip(&field.attrs) {
None
} else {
Some(&field.ty)
})
),
Some(field)
}
})),
Fields::Unit => Box::new(std::iter::empty()),
};
// expands to an expression like
Expand All @@ -92,9 +92,16 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
// `max_encoded_len` call. This way, if one field's type doesn't implement
// `MaxEncodedLen`, the compiler's error message will underline which field
// caused the issue.
let expansion = type_iter.map(|ty| {
quote_spanned! {
ty.span() => .saturating_add(<#ty>::max_encoded_len())
let expansion = fields_iter.map(|field| {
let ty = &field.ty;
if utils::is_compact(&field) {
quote_spanned! {
ty.span() => .saturating_add(<#crate_path::Compact::<#ty> as #crate_path::MaxEncodedLen>::max_encoded_len())
}
} else {
quote_spanned! {
ty.span() => .saturating_add(<#ty as #crate_path::MaxEncodedLen>::max_encoded_len())
}
}
});
quote! {
Expand All @@ -103,9 +110,9 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
}

// generate an expression to sum up the max encoded length of each field
fn data_length_expr(data: &Data) -> proc_macro2::TokenStream {
fn data_length_expr(data: &Data, crate_path: &syn::Path) -> proc_macro2::TokenStream {
match *data {
Data::Struct(ref data) => fields_length_expr(&data.fields),
Data::Struct(ref data) => fields_length_expr(&data.fields, crate_path),
Data::Enum(ref data) => {
// We need an expression expanded for each variant like
//
Expand All @@ -121,7 +128,7 @@ fn data_length_expr(data: &Data) -> proc_macro2::TokenStream {
// Each variant expression's sum is computed the way an equivalent struct's would be.

let expansion = data.variants.iter().map(|variant| {
let variant_expression = fields_length_expr(&variant.fields);
let variant_expression = fields_length_expr(&variant.fields, crate_path);
quote! {
.max(#variant_expression)
}
Expand Down
25 changes: 24 additions & 1 deletion tests/max_encoded_len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//! Tests for MaxEncodedLen derive macro
#![cfg(all(feature = "derive", feature = "max-encoded-len"))]

use parity_scale_codec::{MaxEncodedLen, Compact, Decode, Encode};
use parity_scale_codec::{Compact, Decode, Encode, MaxEncodedLen};

#[derive(Encode, MaxEncodedLen)]
struct Primitives {
Expand Down Expand Up @@ -64,6 +64,29 @@ fn generic_max_length() {
assert_eq!(Generic::<u32>::max_encoded_len(), u32::max_encoded_len() * 2);
}

#[derive(Encode, MaxEncodedLen)]
struct CompactField {
#[codec(compact)]
t: u64,
v: u64,
}
pgherveou marked this conversation as resolved.
Show resolved Hide resolved

#[test]
fn compact_field_max_length() {
assert_eq!(
CompactField::max_encoded_len(),
Compact::<u64>::max_encoded_len() + u64::max_encoded_len()
);
}

#[derive(Encode, MaxEncodedLen)]
struct CompactStruct(#[codec(compact)] u64);

#[test]
fn compact_struct_max_length() {
assert_eq!(CompactStruct::max_encoded_len(), Compact::<u64>::max_encoded_len());
}

#[derive(Encode, MaxEncodedLen)]
struct TwoGenerics<T, U> {
t: T,
Expand Down
20 changes: 12 additions & 8 deletions tests/max_encoded_len_ui/unsupported_variant.stderr
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
error[E0599]: no function or associated item named `max_encoded_len` found for struct `NotMel` in the current scope
error[E0277]: the trait bound `NotMel: MaxEncodedLen` is not satisfied
--> tests/max_encoded_len_ui/unsupported_variant.rs:8:9
|
4 | struct NotMel;
| ------------- function or associated item `max_encoded_len` not found for this struct
...
8 | NotMel(NotMel),
| ^^^^^^ function or associated item not found in `NotMel`
| ^^^^^^ the trait `MaxEncodedLen` is not implemented for `NotMel`
|
= help: items from traits can only be used if the trait is implemented and in scope
= note: the following trait defines an item `max_encoded_len`, perhaps you need to implement it:
candidate #1: `MaxEncodedLen`
= help: the following other types implement trait `MaxEncodedLen`:
()
(TupleElement0, TupleElement1)
(TupleElement0, TupleElement1, TupleElement2)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6, TupleElement7)
and $N others