Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimised the Decode::decode for [T; N] #299

Merged
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ serde_derive = { version = "1.0" }
parity-scale-codec-derive = { path = "derive", default-features = false }
quickcheck = "1.0"
trybuild = "1.0.42"
paste = "1"
bkchr marked this conversation as resolved.
Show resolved Hide resolved

[[bench]]
name = "benches"
Expand Down
122 changes: 111 additions & 11 deletions src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ use core::{
iter::FromIterator,
marker::PhantomData,
mem,
mem::{
MaybeUninit,
forget,
},
ops::{Deref, Range, RangeInclusive},
time::Duration,
ptr,
};
use core::num::{
NonZeroI8,
Expand All @@ -35,7 +40,6 @@ use core::num::{
NonZeroU64,
NonZeroU128,
};
use arrayvec::ArrayVec;

use byte_slice_cast::{AsByteSlice, AsMutByteSlice, ToMutByteSlice};

Expand Down Expand Up @@ -637,7 +641,91 @@ pub(crate) fn encode_slice_no_len<T: Encode, W: Output + ?Sized>(slice: &[T], de
}
}

/// Decode the slice (without prepended the len).
/// Decode the array.
///
/// This is equivalent to decoding all the element one by one, but it is optimized for some types.
#[inline]
pub(crate) fn decode_array<I: Input, T: Decode, const N: usize>(input: &mut I) -> Result<[T; N], Error> {
#[inline]
fn general_array_decode<I: Input, T: Decode, const N: usize>(input: &mut I) -> Result<[T; N], Error> {
let mut uninit = <MaybeUninit<[T; N]>>::uninit();
// The following line coerces the pointer to the array to a pointer
// to the first array element which is equivalent.
let mut ptr = uninit.as_mut_ptr() as *mut T;
for _ in 0..N {
let decoded = T::decode(input)?;
// SAFETY: We do not read uninitialized array contents
// while initializing them.
unsafe {
ptr::write(ptr, decoded);
}
// SAFETY: Point to the next element after every iteration.
// We do this N times therefore this is safe.
ptr = unsafe { ptr.add(1) };
}
// SAFETY: All array elements have been initialized above.
let init = unsafe { uninit.assume_init() };
Ok(init)
}

// Description for the code below.
// It is not possible to transmute `[u8; N]` into `[T; N]` due to this issue:
// https://github.com/rust-lang/rust/issues/61956
//
// Workaround: Transmute `&[u8; N]` into `&[T; N]` and interpret that reference as value.
// ```
// let mut array: [u8; N] = [0; N];
// let ref_typed: &[T; N] = unsafe { mem::transmute(&array) };
// let typed: [T; N] = unsafe { ptr::read(ref_typed) };
// forget(array);
// Here `array` and `typed` points on the same memory.
// Function returns `typed` -> it is not dropped, but `array` will be dropped.
// To avoid that `array` should be forgotten.
// ```
macro_rules! decode {
( u8 ) => {{
let mut array: [u8; N] = [0; N];
input.read(&mut array[..])?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a memory leak if read returns an error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!=) @Robbepop I returned back the usage of forget

let ref_typed: &[T; N] = unsafe { mem::transmute(&array) };
let typed: [T; N] = unsafe { ptr::read(ref_typed) };
forget(array);
Ok(typed)
}};
( i8 ) => {{
let mut array: [i8; N] = [0; N];
let bytes = unsafe { mem::transmute::<&mut [i8], &mut [u8]>(&mut array[..]) };
input.read(bytes)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same


let ref_typed: &[T; N] = unsafe { mem::transmute(&array) };
let typed: [T; N] = unsafe { ptr::read(ref_typed) };
forget(array);
Ok(typed)
}};
( $ty:ty ) => {{
if cfg!(target_endian = "little") {
let mut array: [$ty; N] = [0; N];
let bytes = <[$ty] as AsMutByteSlice<$ty>>::as_mut_byte_slice(&mut array[..]);
input.read(bytes)?;
let ref_typed: &[T; N] = unsafe { mem::transmute(&array) };
let typed: [T; N] = unsafe { ptr::read(ref_typed) };
forget(array);
Ok(typed)
} else {
general_array_decode(input)
}
}};
}

with_type_info! {
<T as Decode>::TYPE_INFO,
decode,
{
general_array_decode(input)
},
}
}

/// Decode the vec (without prepended the len).
///
/// This is equivalent to decode all elements one by one, but it is optimized in some
/// situation.
Expand Down Expand Up @@ -706,16 +794,9 @@ impl<T: Encode, const N: usize> Encode for [T; N] {
}

impl<T: Decode, const N: usize> Decode for [T; N] {
#[inline]
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
let mut array = ArrayVec::new();
for _ in 0..N {
array.push(T::decode(input)?);
}

match array.into_inner() {
Ok(a) => Ok(a),
Err(_) => panic!("We decode `N` elements; qed"),
}
decode_array(input)
}
}

Expand Down Expand Up @@ -1647,6 +1728,25 @@ mod tests {
<[u32; 0]>::decode(&mut &encoded[..]).unwrap();
}


macro_rules! test_array_encode_and_decode {
( $( $name:ty ),* $(,)? ) => {
$(
paste::item! {
#[test]
fn [<test_array_encode_and_decode _ $name>]() {
let data: [$name; 32] = [123; 32];
let encoded = data.encode();
let decoded: [$name; 32] = Decode::decode(&mut &encoded[..]).unwrap();
assert_eq!(decoded, data);
}
}
)*
}
}

test_array_encode_and_decode!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128);

fn test_encoded_size(val: impl Encode) {
let length = val.using_encoded(|v| v.len());

Expand Down
8 changes: 4 additions & 4 deletions tests/chain-error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ struct Wrapper<T>(T);

#[derive(Decode, Debug)]
struct StructNamed {
foo: u16
_foo: u16
}

#[derive(Decode, Debug)]
struct StructUnnamed(u16);

#[derive(Decode, Debug)]
enum E {
VariantNamed { foo: u16, },
VariantNamed { _foo: u16, },
VariantUnnamed(u16),
}

#[test]
fn full_error_struct_named() {
let encoded = vec![0];
let err = r#"Could not decode `Wrapper.0`:
Could not decode `StructNamed::foo`:
Could not decode `StructNamed::_foo`:
Not enough data to fill buffer
"#;

Expand Down Expand Up @@ -75,7 +75,7 @@ fn full_error_enum_unknown_variant() {
#[test]
fn full_error_enum_named_field() {
let encoded = vec![0, 0];
let err = r#"Could not decode `E::VariantNamed::foo`:
let err = r#"Could not decode `E::VariantNamed::_foo`:
Not enough data to fill buffer
"#;

Expand Down
10 changes: 5 additions & 5 deletions tests/max_encoded_len_ui/crate_str.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ error[E0277]: the trait bound `Example: WrapperTypeEncode` is not satisfied
3 | #[derive(Encode, MaxEncodedLen)]
| ^^^^^^^^^^^^^ the trait `WrapperTypeEncode` is not implemented for `Example`
|
::: $WORKSPACE/src/max_encoded_len.rs
|
| pub trait MaxEncodedLen: Encode {
| ------ required by this bound in `MaxEncodedLen`
|
= note: required because of the requirements on the impl of `Encode` for `Example`
note: required by a bound in `MaxEncodedLen`
--> $DIR/max_encoded_len.rs:28:26
|
28 | pub trait MaxEncodedLen: Encode {
| ^^^^^^ required by this bound in `MaxEncodedLen`
= note: this error originates in the derive macro `MaxEncodedLen` (in Nightly builds, run with -Z macro-backtrace for more info)
10 changes: 5 additions & 5 deletions tests/max_encoded_len_ui/incomplete_attr.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ error[E0277]: the trait bound `Example: WrapperTypeEncode` is not satisfied
3 | #[derive(Encode, MaxEncodedLen)]
| ^^^^^^^^^^^^^ the trait `WrapperTypeEncode` is not implemented for `Example`
|
::: $WORKSPACE/src/max_encoded_len.rs
|
| pub trait MaxEncodedLen: Encode {
| ------ required by this bound in `MaxEncodedLen`
|
= note: required because of the requirements on the impl of `Encode` for `Example`
note: required by a bound in `MaxEncodedLen`
--> $DIR/max_encoded_len.rs:28:26
|
28 | pub trait MaxEncodedLen: Encode {
| ^^^^^^ required by this bound in `MaxEncodedLen`
= note: this error originates in the derive macro `MaxEncodedLen` (in Nightly builds, run with -Z macro-backtrace for more info)
10 changes: 5 additions & 5 deletions tests/max_encoded_len_ui/missing_crate_specifier.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ error[E0277]: the trait bound `Example: WrapperTypeEncode` is not satisfied
3 | #[derive(Encode, MaxEncodedLen)]
| ^^^^^^^^^^^^^ the trait `WrapperTypeEncode` is not implemented for `Example`
|
::: $WORKSPACE/src/max_encoded_len.rs
|
| pub trait MaxEncodedLen: Encode {
| ------ required by this bound in `MaxEncodedLen`
|
= note: required because of the requirements on the impl of `Encode` for `Example`
note: required by a bound in `MaxEncodedLen`
--> $DIR/max_encoded_len.rs:28:26
|
28 | pub trait MaxEncodedLen: Encode {
| ^^^^^^ required by this bound in `MaxEncodedLen`
= note: this error originates in the derive macro `MaxEncodedLen` (in Nightly builds, run with -Z macro-backtrace for more info)
10 changes: 5 additions & 5 deletions tests/max_encoded_len_ui/not_encode.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ error[E0277]: the trait bound `NotEncode: WrapperTypeEncode` is not satisfied
3 | #[derive(MaxEncodedLen)]
| ^^^^^^^^^^^^^ the trait `WrapperTypeEncode` is not implemented for `NotEncode`
|
::: $WORKSPACE/src/max_encoded_len.rs
|
| pub trait MaxEncodedLen: Encode {
| ------ required by this bound in `MaxEncodedLen`
|
= note: required because of the requirements on the impl of `Encode` for `NotEncode`
note: required by a bound in `MaxEncodedLen`
--> $DIR/max_encoded_len.rs:28:26
|
28 | pub trait MaxEncodedLen: Encode {
| ^^^^^^ required by this bound in `MaxEncodedLen`
= note: this error originates in the derive macro `MaxEncodedLen` (in Nightly builds, run with -Z macro-backtrace for more info)