Skip to content

Commit

Permalink
[derive] Automatically derive super-traits
Browse files Browse the repository at this point in the history
`#[derive(FromBytes)]` implies `#[derive(FromZeros)]`, which implies
`#[derive(TryFromBytes)]`.

Closes #925
  • Loading branch information
joshlf committed Feb 23, 2024
1 parent ca49473 commit 11ced18
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 90 deletions.
6 changes: 3 additions & 3 deletions src/byteorder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
//!
//! ```rust,edition2021
//! # #[cfg(feature = "derive")] { // This example uses derives, and won't compile without them
//! use zerocopy::{IntoBytes, ByteSlice, FromBytes, FromZeros, NoCell, Ref, Unaligned};
//! use zerocopy::{IntoBytes, ByteSlice, FromBytes, NoCell, Ref, Unaligned};
//! use zerocopy::byteorder::network_endian::U16;
//!
//! #[derive(FromZeros, FromBytes, IntoBytes, NoCell, Unaligned)]
//! #[derive(FromBytes, IntoBytes, NoCell, Unaligned)]
//! #[repr(C)]
//! struct UdpHeader {
//! src_port: U16,
Expand Down Expand Up @@ -357,7 +357,7 @@ example of how it can be used for parsing UDP packets.
[`IntoBytes`]: crate::IntoBytes
[`Unaligned`]: crate::Unaligned"),
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(any(feature = "derive", test), derive(KnownLayout, NoCell, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned))]
#[cfg_attr(any(feature = "derive", test), derive(KnownLayout, NoCell, FromBytes, IntoBytes, Unaligned))]
#[repr(transparent)]
pub struct $name<O>([u8; $bytes], PhantomData<O>);
}
Expand Down
12 changes: 5 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5766,9 +5766,7 @@ mod tests {
//
// This is used to test the custom derives of our traits. The `[u8]` type
// gets a hand-rolled impl, so it doesn't exercise our custom derives.
#[derive(
Debug, Eq, PartialEq, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned, NoCell,
)]
#[derive(Debug, Eq, PartialEq, FromBytes, IntoBytes, Unaligned, NoCell)]
#[repr(transparent)]
struct Unsized([u8]);

Expand Down Expand Up @@ -7896,7 +7894,7 @@ mod tests {
assert_eq!(too_many_bytes[0], 123);
}

#[derive(Debug, Eq, PartialEq, TryFromBytes, FromZeros, FromBytes, IntoBytes, NoCell)]
#[derive(Debug, Eq, PartialEq, FromBytes, IntoBytes, NoCell)]
#[repr(C)]
struct Foo {
a: u32,
Expand Down Expand Up @@ -7925,7 +7923,7 @@ mod tests {

#[test]
fn test_array() {
#[derive(TryFromBytes, FromZeros, FromBytes, IntoBytes, NoCell)]
#[derive(FromBytes, IntoBytes, NoCell)]
#[repr(C)]
struct Foo {
a: [u16; 33],
Expand Down Expand Up @@ -7989,7 +7987,7 @@ mod tests {

#[test]
fn test_transparent_packed_generic_struct() {
#[derive(IntoBytes, TryFromBytes, FromZeros, FromBytes, Unaligned)]
#[derive(IntoBytes, FromBytes, Unaligned)]
#[repr(transparent)]
struct Foo<T> {
_t: T,
Expand All @@ -7999,7 +7997,7 @@ mod tests {
assert_impl_all!(Foo<u32>: FromZeros, FromBytes, IntoBytes);
assert_impl_all!(Foo<u8>: Unaligned);

#[derive(IntoBytes, TryFromBytes, FromZeros, FromBytes, Unaligned)]
#[derive(IntoBytes, FromBytes, Unaligned)]
#[repr(packed)]
struct Bar<T, U> {
_t: T,
Expand Down
6 changes: 1 addition & 5 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ pub(crate) mod testutil {
#[derive(
KnownLayout,
NoCell,
TryFromBytes,
FromZeros,
FromBytes,
IntoBytes,
Eq,
Expand Down Expand Up @@ -249,9 +247,7 @@ pub(crate) mod testutil {
}
}

#[derive(
NoCell, FromZeros, FromBytes, Eq, PartialEq, Ord, PartialOrd, Default, Debug, Copy, Clone,
)]
#[derive(NoCell, FromBytes, Eq, PartialEq, Ord, PartialOrd, Default, Debug, Copy, Clone)]
#[repr(C)]
pub(crate) struct Nested<T, U: ?Sized> {
_t: T,
Expand Down
2 changes: 1 addition & 1 deletion src/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use super::*;
#[derive(Default, Copy)]
#[cfg_attr(
any(feature = "derive", test),
derive(NoCell, KnownLayout, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned)
derive(NoCell, KnownLayout, FromBytes, IntoBytes, Unaligned)
)]
#[repr(C, packed)]
pub struct Unalign<T>(T);
Expand Down
97 changes: 68 additions & 29 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,16 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt

#[proc_macro_derive(FromZeros)]
pub fn derive_from_zeros(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let try_from_bytes = derive_try_from_bytes(ts.clone());

let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
let from_zeros = match &ast.data {
Data::Struct(strct) => derive_from_zeros_struct(&ast, strct),
Data::Enum(enm) => derive_from_zeros_enum(&ast, enm),
Data::Union(unn) => derive_from_zeros_union(&ast, unn),
}
.into()
.into();
IntoIterator::into_iter([try_from_bytes, from_zeros]).collect()
}

/// Deprecated: prefer [`FromZeros`] instead.
Expand All @@ -314,13 +317,17 @@ pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStrea

#[proc_macro_derive(FromBytes)]
pub fn derive_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let from_zeros = derive_from_zeros(ts.clone());

let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
let from_bytes = match &ast.data {
Data::Struct(strct) => derive_from_bytes_struct(&ast, strct),
Data::Enum(enm) => derive_from_bytes_enum(&ast, enm),
Data::Union(unn) => derive_from_bytes_union(&ast, unn),
}
.into()
.into();

IntoIterator::into_iter([from_zeros, from_bytes]).collect()
}

#[proc_macro_derive(IntoBytes)]
Expand Down Expand Up @@ -447,25 +454,33 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
.to_compile_error();
}

// We don't actually care what the repr is; we just care that it's one of
// the allowed ones.
try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));

// Figure out whether the enum could in theory implement `FromBytes`.
let from_bytes = enum_size_from_repr(reprs.as_slice())
.map(|size| {
// As of this writing, `enm.is_fieldless()` is redundant since we've
// already checked for it and returned if the check failed. However, if
// we ever remove that check, then without a similar check here, this
// code would become unsound.
enm.is_fieldless() && enm.variants.len() == 1usize << size
})
.unwrap_or(false);

let variant_names = enm.variants.iter().map(|v| &v.ident);
let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the field-less enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
'_,
Self,
(
::zerocopy::pointer::invariant::Shared,
::zerocopy::pointer::invariant::AnyAlignment,
::zerocopy::pointer::invariant::Initialized,
),
>,
) -> ::zerocopy::macro_util::core_reexport::primitive::bool {
let is_bit_valid_body = if from_bytes {
// If the enum could implement `FromBytes`, we can avoid emitting a
// match statement. This is faster to compile, and generates code which
// performs better.
quote!({
// Prevent an "unused" warning.
let _ = candidate;
// SAFETY: If the enum could implement `FromBytes`, then all bit
// patterns are valid. Thus, this is a sound implementation.
true
})
} else {
quote!(
use ::zerocopy::macro_util::core_reexport;
// SAFETY:
// - `cast` is implemented as required.
Expand Down Expand Up @@ -499,6 +514,25 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
// `candidate` refers to a bit-valid `Self`.
discriminant == d
})*
)
};

let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the field-less enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
'_,
Self,
(
::zerocopy::pointer::invariant::Shared,
::zerocopy::pointer::invariant::AnyAlignment,
::zerocopy::pointer::invariant::Initialized,
),
>,
) -> ::zerocopy::macro_util::core_reexport::primitive::bool {
#is_bit_valid_body
}
));
impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras)
Expand Down Expand Up @@ -608,13 +642,9 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok

let reprs = try_or_print!(ENUM_FROM_BYTES_CFG.validate_reprs(ast));

let variants_required = match reprs.as_slice() {
[EnumRepr::U8] | [EnumRepr::I8] => 1usize << 8,
[EnumRepr::U16] | [EnumRepr::I16] => 1usize << 16,
// `validate_reprs` has already validated that it's one of the preceding
// patterns.
_ => unreachable!(),
};
let variants_required = 1usize
<< enum_size_from_repr(reprs.as_slice())
.expect("internal error: `validate_reprs` has already validated that the reprs guarantee the enum's size");
if enm.variants.len() != variants_required {
return Error::new_spanned(
ast,
Expand All @@ -629,6 +659,15 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok
impl_block(ast, enm, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

// Returns `None` if the enum's size is not guaranteed by the repr.
fn enum_size_from_repr(reprs: &[EnumRepr]) -> Option<usize> {
match reprs {
[EnumRepr::U8] | [EnumRepr::I8] => Some(8),
[EnumRepr::U16] | [EnumRepr::I16] => Some(16),
_ => None,
}
}

#[rustfmt::skip]
const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Expand Down
12 changes: 6 additions & 6 deletions zerocopy-derive/tests/enum_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ include!("include.rs");
// `Variant128` has a discriminant of -128) since Rust won't automatically wrap
// a signed discriminant around without you explicitly telling it to.

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(u8)]
enum FooU8 {
Variant0,
Expand Down Expand Up @@ -292,7 +292,7 @@ enum FooU8 {

util_assert_impl_all!(FooU8: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(i8)]
enum FooI8 {
Variant0,
Expand Down Expand Up @@ -555,7 +555,7 @@ enum FooI8 {

util_assert_impl_all!(FooI8: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(u8, align(2))]
enum FooU8Align {
Variant0,
Expand Down Expand Up @@ -818,7 +818,7 @@ enum FooU8Align {

util_assert_impl_all!(FooU8Align: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(i8, align(2))]
enum FooI8Align {
Variant0,
Expand Down Expand Up @@ -1081,7 +1081,7 @@ enum FooI8Align {

util_assert_impl_all!(FooI8Align: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(u16)]
enum FooU16 {
Variant0,
Expand Down Expand Up @@ -66624,7 +66624,7 @@ enum FooU16 {

util_assert_impl_all!(FooU16: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(i16)]
enum FooI16 {
Variant0,
Expand Down
9 changes: 1 addition & 8 deletions zerocopy-derive/tests/hygiene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,7 @@ include!("include.rs");

extern crate zerocopy as _zerocopy;

// #[macro_use]
// mod util;

// use std::{marker::PhantomData, option::IntoIter};

#[derive(
_zerocopy::KnownLayout, _zerocopy::FromZeros, _zerocopy::FromBytes, _zerocopy::Unaligned,
)]
#[derive(_zerocopy::KnownLayout, _zerocopy::FromBytes, _zerocopy::Unaligned)]
#[repr(C)]
struct TypeParams<'a, T, I: imp::Iterator> {
a: T,
Expand Down
2 changes: 0 additions & 2 deletions zerocopy-derive/tests/include.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ pub mod util {
#[derive(
super::imp::KnownLayout,
super::imp::NoCell,
super::imp::TryFromBytes,
super::imp::FromZeros,
super::imp::FromBytes,
super::imp::IntoBytes,
Copy,
Expand Down
6 changes: 3 additions & 3 deletions zerocopy-derive/tests/paths_and_modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ include!("include.rs");
mod foo {
use super::*;

#[derive(imp::FromZeros, imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[derive(imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[repr(C)]
pub struct Foo {
foo: u8,
}

#[derive(imp::FromZeros, imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[derive(imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[repr(C)]
pub struct Bar {
bar: u8,
Expand All @@ -32,7 +32,7 @@ mod foo {

use foo::Foo;

#[derive(imp::FromZeros, imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[derive(imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[repr(C)]
struct Baz {
foo: Foo,
Expand Down
Loading

0 comments on commit 11ced18

Please sign in to comment.