diff --git a/Cargo.toml b/Cargo.toml index 9d6be11..6fb51ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,12 @@ extern_crate_std = ["extern_crate_alloc"] zeroable_maybe_uninit = [] zeroable_atomics = [] min_const_generics = [] -wasm_simd = [] # Until >= 1.54.0 is MSRV this is an off-by-default feature. -aarch64_simd = [] # Until >= 1.59.0 is MSRV this is an off-by-default feature. + +# Off-by-default until bytemuck's MSRV is raised. +wasm_simd = [] # Requires MSRV >= 1.54.0. +unified_cast = [] # Requires MSRV >= 1.57.0. +aarch64_simd = [] # Requires MSRV >= 1.59.0. +non_null_slice_cast = [] # Requires MSRV >= 1.63.0. # Do not use if you can avoid it, because this is unsound. unsound_ptr_pod_impl = [] @@ -42,6 +46,8 @@ features = [ "zeroable_atomics", "min_const_generics", "wasm_simd", + "unified_cast", + "non_null_slice_cast", ] [package.metadata.playground] @@ -54,4 +60,6 @@ features = [ "zeroable_atomics", "min_const_generics", "wasm_simd", + "unified_cast", + "non_null_slice_cast", ] diff --git a/src/cast.rs b/src/cast.rs new file mode 100644 index 0000000..2430845 --- /dev/null +++ b/src/cast.rs @@ -0,0 +1,1277 @@ +//! Safe byte-wise conversion between types. +//! +//! The public interface consist of a few traits: +//! * `Reinterpret` and `TryReinterpret` implement by-value conversion. +//! * `ReinterpretInner` and `TryReinterpretInner` implement in-place conversion +//! for various container types. +//! +//! `ReinterpretInner` is backed by several helper traits: +//! * `Container` is a concrete container with a pointer to zero or more items. +//! This trait handles several things: +//! * Holding a tag type representing the containers type class. This prevents +//! casts between incompatible containers and allows any additional +//! constraints required by the container not handled by it's raw +//! representation. +//! * Conversion to and from a container's raw form. This allows the actual +//! cast implementation to be shared between multiple containers. +//! * Whether the original value should be returned in case of an error. This +//! allows simplifying the return type of `try_reinterpret_inner` for `Copy` +//! types. +//! Examples of containers include references, raw pointers, `Box` and +//! `Vec`. +//! * `AssertClassContraints` handles any additional constraints a container +//! places on its items. +//! * `RawPointer` Gives a more unified interface for handling `T`, `[T]` and +//! `str` for `Container` impls. +//! * `CastRaw` implements the conversion between a container's raw forms. This +//! is responsible for verifying that the container's item types are +//! compatible and the item's size/alignment constraints are met. +//! * `TryCastRaw` is the falliable form of `CastRaw` + +use core::{ + marker::{PhantomData, Unpin}, + mem::{align_of, size_of, ManuallyDrop}, + ops::Deref, + pin::Pin, + ptr::{self, NonNull}, + sync::atomic::AtomicPtr, +}; + +#[cfg(feature = "extern_crate_alloc")] +use alloc::{ + borrow::{Cow, ToOwned}, + boxed::Box, + rc::{Rc, Weak as RcWeak}, + sync::{Arc, Weak as ArcWeak}, + vec::Vec, +}; + +use crate::{ + static_assert::*, AnyBitPattern, CheckedBitPattern, NoUninit, PodCastError, +}; + +/// Safe byte-wise conversion of a value. +/// +/// This requires the `unified_cast` feature to be enabled and a rust version +/// `>=1.57`. +pub trait Reinterpret: Sized { + /// Performs the conversion. + fn reinterpret(self) -> T; +} +impl Reinterpret for T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn reinterpret(self) -> U { + static_assert!(AssertSameSize(T, U)); + // SAFETY: + // There are no uninitialized bytes in the source type. + // All bit patterns are accepted by the target type. + // Both types have the same size. + unsafe { (&ManuallyDrop::new(self) as *const _ as *const U).read() } + } +} + +/// Safe byte-wise conversion of a value which may fail due to the source value +/// not being a valid bit pattern for the target type. +/// +/// This requires the `unified_cast` feature to be enabled and a rust version +/// `>=1.57`. +pub trait TryReinterpret: Sized { + /// Performs the conversion. + fn try_reinterpret(self) -> Option; +} +impl TryReinterpret for T +where + T: NoUninit, + U: CheckedBitPattern, +{ + fn try_reinterpret(self) -> Option { + static_assert!(AssertSameSize(T, U)); + let bits: U::Bits = self.reinterpret(); + if U::is_valid_bit_pattern(&bits) { + // SAFETY: + // There are no uninitialized bytes in the source type. + // The value has been confirmed to be a valid bit pattern for the target type. + // Both types have the same size. + Some(unsafe { (&ManuallyDrop::new(self) as *const _ as *const U).read() }) + } else { + None + } + } +} + +// Container classes. +pub struct RefT; +pub struct PtrT; +pub struct NonNullT; +pub struct AtomicPtrT; +pub struct OptionT(PhantomData); +pub struct PinT(PhantomData); +#[cfg(feature = "extern_crate_alloc")] +pub struct BoxT; +#[cfg(feature = "extern_crate_alloc")] +pub struct CowT; +#[cfg(feature = "extern_crate_alloc")] +pub struct RcT; +#[cfg(feature = "extern_crate_alloc")] +pub struct RcWeakT; +#[cfg(feature = "extern_crate_alloc")] +pub struct ArcT; +#[cfg(feature = "extern_crate_alloc")] +pub struct ArcWeakT; +#[cfg(feature = "extern_crate_alloc")] +pub struct VecT; + +/// Policy trait for whether the original value should be returned with the +/// error. +pub trait CastErrWithValue { + /// The error type returned. + type Err; + /// Combines the original value with the error. + fn cast_error_with_value(err: PodCastError, value: T) -> Self::Err; +} +/// Return the error without the original value. +pub struct OnlyErr; +impl CastErrWithValue for OnlyErr { + type Err = PodCastError; + fn cast_error_with_value(err: PodCastError, _: T) -> Self::Err { + err + } +} +/// Return both the error and the original value. +pub struct ErrWithValue; +impl CastErrWithValue for ErrWithValue { + type Err = (PodCastError, T); + fn cast_error_with_value(err: PodCastError, value: T) -> Self::Err { + (err, value) + } +} + +/// Like `*const [T]`, but the length can be retrieved safely. +pub struct RawSlice { + data: *const T, + len: usize, +} +impl Clone for RawSlice { + fn clone(&self) -> Self { + *self + } +} +impl Copy for RawSlice {} + +/// Like `*mut [T]`, but the length can be retrieved safely. +pub struct RawMutSlice { + data: *mut T, + len: usize, +} +impl Clone for RawMutSlice { + fn clone(&self) -> Self { + *self + } +} +impl Copy for RawMutSlice {} + +/// A single byte from a `str` slice. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct StrByte(u8); +// SAFETY: A transparent wrapper over a single byte has no uninitialized bytes. +unsafe impl NoUninit for StrByte {} + +/// Converts between a pointer type and it's raw form. This allows treating +/// various DST pointers similarly. +/// +/// # Safety +/// Converting a pointer to and from it's raw form must result in the same +/// pointer. +pub unsafe trait RawPtr { + /// The raw form of the pointer. + type Raw: Copy; + + /// Performs the conversion to the raw form. + /// + /// # Safety + /// For DST targets the pointer must be safe to materialize. If the + /// `non_null_slice_cast` feature is enabled, then it must only be non-null. + unsafe fn into_raw_ptr(self) -> Self::Raw; + + /// Performs the conversion from the raw form. + /// + /// # Safety + /// The value must have been created from `into_raw_ptr` and possibly + /// converted to a different type using `CastRaw` or `TryCastRaw`. + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self; +} +unsafe impl RawPtr for *const T { + type Raw = *const T; + unsafe fn into_raw_ptr(self) -> Self::Raw { + self + } + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self { + raw + } +} +unsafe impl RawPtr for *mut T { + type Raw = *mut T; + unsafe fn into_raw_ptr(self) -> Self::Raw { + self + } + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self { + raw + } +} +unsafe impl RawPtr for *const [T] { + type Raw = RawSlice; + #[cfg(feature = "non_null_slice_cast")] + unsafe fn into_raw_ptr(self) -> Self::Raw { + let len = NonNull::new_unchecked(self as *mut [T]).len(); + RawSlice { data: self as *const T, len } + } + #[cfg(not(feature = "non_null_slice_cast"))] + unsafe fn into_raw_ptr(self) -> Self::Raw { + let len = (*self).len(); + RawSlice { data: self as *const T, len } + } + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self { + ptr::slice_from_raw_parts(raw.data, raw.len) + } +} +unsafe impl RawPtr for *mut [T] { + type Raw = RawMutSlice; + #[cfg(feature = "non_null_slice_cast")] + unsafe fn into_raw_ptr(self) -> Self::Raw { + let len = NonNull::new_unchecked(self).len(); + RawMutSlice { data: self as *mut T, len } + } + #[cfg(not(feature = "non_null_slice_cast"))] + unsafe fn into_raw_ptr(self) -> Self::Raw { + let len = (*self).len(); + RawMutSlice { data: self as *mut T, len } + } + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self { + ptr::slice_from_raw_parts_mut(raw.data, raw.len) + } +} +unsafe impl RawPtr for *const str { + type Raw = RawSlice; + unsafe fn into_raw_ptr(self) -> Self::Raw { + (self as *const [StrByte]).into_raw_ptr() + } + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self { + <*const [StrByte]>::from_raw_ptr(raw) as *const str + } +} +unsafe impl RawPtr for *mut str { + type Raw = RawMutSlice; + unsafe fn into_raw_ptr(self) -> Self::Raw { + (self as *mut [StrByte]).into_raw_ptr() + } + unsafe fn from_raw_ptr(raw: Self::Raw) -> Self { + <*mut [StrByte]>::from_raw_ptr(raw) as *mut str + } +} + +/// A concrete container type. e.g `&'a T` or `Box` +/// +/// # Safety +/// * `Class` must be set such that calling `CastRaw` on this containers raw +/// form is valid. +/// * `Item` must match the contained item type. +/// * `into_raw` -> `CastRaw`/`TryCastRaw` -> `from_raw` must be valid. +pub unsafe trait Container<'a>: Sized { + /// The type class of this container. Used to limit which raw casts should be + type Class: AssertClassContraints; + /// The item type held within this container. + type Item: 'a + ?Sized + ItemLayout; + /// The 'raw' form of this container. Used to allow different containers to + /// share the same `CastRaw` and `TryCastRaw` impls. + type Raw: 'a + Copy; + /// Whether the cast should return the original value along with the error. + type CastErr: CastErrWithValue; + + /// Converts the container into it's raw form. + fn into_raw(self) -> Self::Raw; + + /// Reconstructs the container from it's raw form. + /// + /// # Safety + /// The values must have to come from `into_parts` of the same container + /// class. Casting to a different item type must meet the following + /// constraints: + /// * Casting between zero-sized types and non-zero-sized types is forbidden. + /// * The data pointer's alignment must meet the alignment constraints of it's + /// item type. + /// * Size and alignment requirements of the container class must be met. + /// * The additional data must be adjusted for the new type. + unsafe fn from_raw(raw: Self::Raw) -> Self; +} + +unsafe impl<'a, T> Container<'a> for &'a T +where + T: 'a + ?Sized + ItemLayout, + *const T: RawPtr, +{ + type Class = RefT; + type Item = T; + type Raw = <*const T as RawPtr>::Raw; + type CastErr = OnlyErr; + + fn into_raw(self) -> Self::Raw { + // SAFETY: Materializing pointers backed by a reference is safe. + unsafe { (self as *const T).into_raw_ptr() } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + &*<*const T>::from_raw_ptr(raw) + } +} + +unsafe impl<'a, T> Container<'a> for &'a mut T +where + T: 'a + ?Sized + ItemLayout, + *mut T: RawPtr, +{ + type Class = RefT; + type Item = T; + type Raw = <*mut T as RawPtr>::Raw; + type CastErr = OnlyErr; + + fn into_raw(self) -> Self::Raw { + // SAFETY: Materializing pointers backed by a reference is safe. + unsafe { (self as *mut T).into_raw_ptr() } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + &mut *<*mut T>::from_raw_ptr(raw) + } +} + +// No safe way to get the length of a slice. Only implement for sized types. +unsafe impl<'a, T: 'a> Container<'a> for *const T { + type Class = PtrT; + type Item = T; + type Raw = Self; + type CastErr = OnlyErr; + + fn into_raw(self) -> Self::Raw { + self + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + raw + } +} + +// No safe way to get the length of a slice. Only implement for sized types. +unsafe impl<'a, T: 'a> Container<'a> for *mut T { + type Class = PtrT; + type Item = T; + type Raw = Self; + type CastErr = OnlyErr; + + fn into_raw(self) -> Self::Raw { + self + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + raw + } +} + +unsafe impl< + 'a, + #[cfg(feature = "non_null_slice_cast")] T: 'a + ?Sized + ItemLayout, + #[cfg(not(feature = "non_null_slice_cast"))] T: 'a, + > Container<'a> for NonNull +where + *mut T: RawPtr, +{ + type Class = NonNullT; + type Item = T; + type Raw = <*mut T as RawPtr>::Raw; + type CastErr = OnlyErr; + + fn into_raw(self) -> Self::Raw { + // SAFETY: We only attempt to get the length with the `non_null_slice_cast` + // feature. + unsafe { self.as_ptr().into_raw_ptr() } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + NonNull::new_unchecked(<*mut T>::from_raw_ptr(raw)) + } +} + +unsafe impl<'a, T: 'a> Container<'a> for AtomicPtr { + type Class = AtomicPtrT; + type Item = T; + type Raw = *mut T; + type CastErr = OnlyErr; + + fn into_raw(self) -> Self::Raw { + self.into_inner() + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::new(raw) + } +} + +unsafe impl<'a, C> Container<'a> for Option +where + C: Container<'a>, + C::CastErr: CastErrWithValue, +{ + type Class = OptionT; + type Item = C::Item; + type Raw = Option; + type CastErr = C::CastErr; + + fn into_raw(self) -> Self::Raw { + self.map(|x| x.into_raw()) + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + raw.map(|raw| C::from_raw(raw)) + } +} + +// SAFETY: `Pin` has no safety requirements for types which deref to an `Unpin` +// type. +unsafe impl<'a, C> Container<'a> for Pin +where + C: Container<'a> + Deref>::Item>, + C::Item: Unpin, + C::CastErr: CastErrWithValue, +{ + type Class = PinT; + type Item = C::Item; + type Raw = C::Raw; + type CastErr = C::CastErr; + + fn into_raw(self) -> Self::Raw { + Self::into_inner(self).into_raw() + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::new(C::from_raw(raw)) + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T> Container<'a> for Box +where + T: 'a + ?Sized + ItemLayout, + *const T: RawPtr, +{ + type Class = BoxT; + type Item = T; + // Uses `*const T` as the old value can't be read after the conversion. + type Raw = <*const T as RawPtr>::Raw; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + // SAFETY: Materializing a pointer to an allocated box is safe. + unsafe { (Self::into_raw(self) as *const T).into_raw_ptr() } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::from_raw(<*const T>::from_raw_ptr(raw) as *mut T) + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T> Container<'a> for Rc +where + T: 'a + ?Sized + ItemLayout, + *const T: RawPtr, +{ + type Class = RcT; + type Item = T; + type Raw = <*const T as RawPtr>::Raw; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + // SAFETY: Materializing a pointer to an allocated `Rc` is safe. + unsafe { Self::into_raw(self).into_raw_ptr() } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::from_raw(<*const T>::from_raw_ptr(raw)) + } +} +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T: 'a> Container<'a> for RcWeak { + type Class = RcWeakT; + type Item = T; + // `get_mut_unchecked` requires no other `Rc`s with a different type exist. + type Raw = *const T; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + Self::into_raw(self) + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::from_raw(raw) + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T> Container<'a> for Arc +where + T: 'a + ?Sized + ItemLayout, + *const T: RawPtr, +{ + type Class = ArcT; + type Item = T; + type Raw = <*const T as RawPtr>::Raw; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + // SAFETY: Materializing a pointer to an allocated `Arc` is safe. + unsafe { Self::into_raw(self).into_raw_ptr() } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::from_raw(<*const T>::from_raw_ptr(raw)) + } +} +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T: 'a> Container<'a> for ArcWeak { + type Class = ArcWeakT; + type Item = T; + type Raw = *mut T; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + Self::into_raw(self) as *mut T + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::from_raw(raw) + } +} + +/// The raw form of a vec. +#[cfg(feature = "extern_crate_alloc")] +#[derive(Clone, Copy)] +pub struct RawVec { + slice: RawSlice, + cap: usize, +} +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T: 'a + Copy> Container<'a> for Vec { + type Class = VecT; + type Item = T; + type Raw = RawVec; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + let mut x = ManuallyDrop::new(self); + RawVec { + // Use `as_mut_ptr` to get the correct provenance. + slice: RawSlice { data: x.as_mut_ptr() as *const T, len: x.len() }, + cap: x.capacity(), + } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + Self::from_raw_parts(raw.slice.data as *mut T, raw.slice.len, raw.cap) + } +} + +/// Conversion between a types `Owned` type, and that type's raw form. +#[cfg(feature = "extern_crate_alloc")] +pub trait RawToOwned: ToOwned { + type RawOwned: Copy; + fn raw_from_owned(value: Self::Owned) -> Self::RawOwned; + unsafe fn owned_from_raw(raw: Self::RawOwned) -> Self::Owned; +} +#[cfg(feature = "extern_crate_alloc")] +impl RawToOwned for T { + type RawOwned = T; + fn raw_from_owned(value: Self::Owned) -> Self::RawOwned { + value + } + unsafe fn owned_from_raw(raw: Self::RawOwned) -> Self::Owned { + raw + } +} +#[cfg(feature = "extern_crate_alloc")] +impl RawToOwned for [T] { + type RawOwned = RawVec; + fn raw_from_owned(value: Self::Owned) -> Self::RawOwned { + value.into_raw() + } + unsafe fn owned_from_raw(raw: Self::RawOwned) -> Self::Owned { + Vec::from_raw(raw) + } +} + +/// The raw form of a `Cow`. +#[cfg(feature = "extern_crate_alloc")] +pub enum RawCow +where + *const T: RawPtr, +{ + Borrowed(<*const T as RawPtr>::Raw), + Owned(::RawOwned), +} +#[cfg(feature = "extern_crate_alloc")] +impl Clone for RawCow +where + *const T: RawPtr, +{ + fn clone(&self) -> Self { + *self + } +} +#[cfg(feature = "extern_crate_alloc")] +impl Copy for RawCow where *const T: RawPtr {} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl<'a, T> Container<'a> for Cow<'a, T> +where + T: 'a + ?Sized + RawToOwned + ItemLayout, + *const T: RawPtr, +{ + type Class = CowT; + type Item = T; + type Raw = RawCow; + type CastErr = ErrWithValue; + + fn into_raw(self) -> Self::Raw { + match self { + Self::Borrowed(x) => { + // SAFETY: Materializing a pointer backed by a reference is safe. + RawCow::Borrowed(unsafe { (x as *const T).into_raw_ptr() }) + } + Self::Owned(x) => RawCow::Owned(::raw_from_owned(x)), + } + } + unsafe fn from_raw(raw: Self::Raw) -> Self { + match raw { + RawCow::Borrowed(x) => Self::Borrowed(&*<*const T>::from_raw_ptr(x)), + RawCow::Owned(x) => Self::Owned(::owned_from_raw(x)), + } + } +} + +/// Attempts to convert the pointer type. Will fail if the pointer is not +/// suitably aligned for the target type. +fn try_cast_ptr(ptr: *const T) -> Result<*const U, PodCastError> { + if align_of::() >= align_of::() || ptr as usize % align_of::() == 0 { + Ok(ptr.cast()) + } else { + Err(PodCastError::AlignmentMismatch) + } +} +/// Attempts to convert the pointer type. Will fail if the pointer is not +/// suitably aligned for the target type. +fn try_cast_mut_ptr(ptr: *mut T) -> Result<*mut U, PodCastError> { + Ok(try_cast_ptr::(ptr)? as *mut U) +} + +/// Attempts to convert the length of `[T]` to the length of `[U]`. Will fail if +/// there is no length for the target type which will occupy all the bytes of +/// the input. +fn try_cast_len(size: usize) -> Result { + if size_of::() == size_of::() { + Ok(size) + } else if size_of::() % size_of::() == 0 { + Ok(size * (size_of::() / size_of::())) + } else { + let byte_size = size * size_of::(); + if byte_size % size_of::() == 0 { + Ok(byte_size / size_of::()) + } else { + Err(PodCastError::OutputSliceWouldHaveSlop) + } + } +} + +#[test] +fn test_try_cast_len() { + assert_eq!(try_cast_len::<(), ()>(0), Ok(0)); + assert_eq!(try_cast_len::<(), ()>(1), Ok(1)); + assert_eq!(try_cast_len::<(), ()>(2), Ok(2)); + assert_eq!(try_cast_len::(0), Ok(0)); + assert_eq!(try_cast_len::(1), Ok(1)); + assert_eq!(try_cast_len::(2), Ok(2)); + assert_eq!(try_cast_len::(0), Ok(0)); + assert_eq!(try_cast_len::(2), Ok(1)); + assert_eq!(try_cast_len::(4), Ok(2)); + assert_eq!(try_cast_len::(0), Ok(0)); + assert_eq!(try_cast_len::(4), Ok(1)); + assert_eq!(try_cast_len::(8), Ok(2)); + assert_eq!(try_cast_len::(0), Ok(0)); + assert_eq!(try_cast_len::(1), Ok(1)); + assert_eq!(try_cast_len::(2), Ok(2)); + assert_eq!(try_cast_len::(0), Ok(0)); + assert_eq!(try_cast_len::(2), Ok(1)); + assert_eq!(try_cast_len::(4), Ok(2)); + assert_eq!(try_cast_len::(1), Ok(4)); + assert_eq!(try_cast_len::(2), Ok(8)); + assert_eq!(try_cast_len::(3), Ok(12)); + assert_eq!(try_cast_len::<[u8; 3], u16>(0), Ok(0)); + assert_eq!(try_cast_len::<[u8; 3], u16>(2), Ok(3)); + assert_eq!(try_cast_len::<[u8; 3], u16>(4), Ok(6)); + assert_eq!(try_cast_len::<[u8; 3], u32>(0), Ok(0)); + assert_eq!(try_cast_len::<[u8; 3], u32>(4), Ok(3)); + assert_eq!(try_cast_len::<[u8; 3], u32>(8), Ok(6)); + assert!(try_cast_len::(1).is_err()); + assert!(try_cast_len::(3).is_err()); + assert!(try_cast_len::(5).is_err()); + assert!(try_cast_len::(1).is_err()); + assert!(try_cast_len::(2).is_err()); + assert!(try_cast_len::(3).is_err()); + assert!(try_cast_len::<[u8; 3], [u8; 2]>(1).is_err()); + assert!(try_cast_len::<[u8; 3], [u8; 2]>(3).is_err()); + assert!(try_cast_len::<[u8; 3], [u8; 2]>(5).is_err()); +} + +/// A conversion from a container's raw type to a compatible container's raw +/// type. This is not required to uphold any container specific constraints +/// which would be upheld by `AssertClassContraints`. +/// +/// # Safety +/// Assuming `AssertClassContraints` succeeds, the resulting value must be safe +/// to use for the following two steps: +/// * `RawPtr::from_raw_ptr` if the input was created from `RawPtr::to_raw_ptr`. +/// * `Container::from_raw` for the resulting type's container. +/// +/// Possible constraints include, but are not limited to: +/// * Any contained pointers must point to the same location after the +/// conversion. +/// * Assuming an input pointer is aligned, the matching result pointer must be +/// suitably aligned for the target type. +/// * All bit patterns of any converted input type must be valid for the target +/// type. +/// * For any converted mutable pointers, the reverse is true as well. +/// * Any length field must be converted from it's input type to occupy the same +/// number of bytes in it's output type. +pub unsafe trait CastRaw: Copy { + /// Performs the conversion. + fn cast_raw(self) -> T; +} + +unsafe impl CastRaw<*const U> for *const T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> *const U { + static_assert!(AssertSameSize(T, U)); + static_assert!(AssertMinAlign(T, U)); + self as *const U + } +} +unsafe impl CastRaw<*const U> for *mut T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> *const U { + (self as *const T).cast_raw() + } +} +unsafe impl CastRaw<*mut U> for *mut T +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn cast_raw(self) -> *mut U { + static_assert!(AssertSameSize(T, U)); + static_assert!(AssertMinAlign(T, U)); + self as *mut U + } +} + +unsafe impl CastRaw> for *const T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> RawSlice { + static_assert!(AssertSizeMultipleOf(T, U)); + static_assert!(AssertMinAlign(T, U)); + let len = + if size_of::() == 0 { 1 } else { size_of::() / size_of::() }; + RawSlice { data: self.cast(), len } + } +} +unsafe impl CastRaw> for *mut T +where + *const T: CastRaw>, +{ + fn cast_raw(self) -> RawSlice { + (self as *const T).cast_raw() + } +} +unsafe impl CastRaw> for *mut T +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn cast_raw(self) -> RawMutSlice { + static_assert!(AssertSizeMultipleOf(T, U)); + static_assert!(AssertMinAlign(T, U)); + let len = + if size_of::() == 0 { 1 } else { size_of::() / size_of::() }; + RawMutSlice { data: self.cast(), len } + } +} + +unsafe impl CastRaw> for RawSlice +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> RawSlice { + static_assert!(AssertSizeMultipleOf(T, U)); + static_assert!(AssertMinAlign(T, U)); + let m = + if size_of::() == 0 { 1 } else { size_of::() / size_of::() }; + RawSlice { data: self.data as *const U, len: self.len * m } + } +} +unsafe impl CastRaw> for RawMutSlice +where + RawSlice: CastRaw>, +{ + fn cast_raw(self) -> RawSlice { + RawSlice { data: self.data as *const T, len: self.len }.cast_raw() + } +} +unsafe impl CastRaw> for RawMutSlice +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn cast_raw(self) -> RawMutSlice { + static_assert!(AssertSizeMultipleOf(T, U)); + static_assert!(AssertMinAlign(T, U)); + let m = + if size_of::() == 0 { 1 } else { size_of::() / size_of::() }; + RawMutSlice { data: self.data as *mut U, len: self.len * m } + } +} + +unsafe impl CastRaw> for Option +where + T: CastRaw, + U: Copy, +{ + fn cast_raw(self) -> Option { + self.map(|x| x.cast_raw()) + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl CastRaw> for RawCow +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> RawCow { + match self { + Self::Borrowed(x) => RawCow::Borrowed(x.cast_raw()), + Self::Owned(x) => RawCow::Owned(x.reinterpret()), + } + } +} +#[cfg(feature = "extern_crate_alloc")] +unsafe impl CastRaw> for RawCow<[T]> +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> RawCow<[U]> { + static_assert!(AssertSameAlign(T, U)); + match self { + Self::Borrowed(x) => RawCow::Borrowed(x.cast_raw()), + Self::Owned(x) => RawCow::Owned(x.cast_raw()), + } + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl CastRaw> for RawVec +where + T: NoUninit, + U: AnyBitPattern, +{ + fn cast_raw(self) -> RawVec { + static_assert!(AssertSizeMultipleOf(T, U)); + let m = + if size_of::() == 0 { 1 } else { size_of::() / size_of::() }; + RawVec { slice: self.slice.cast_raw(), cap: self.cap * m } + } +} + +/// An attempted conversion from a container's raw type to a compatible +/// container's raw type. This is not required to uphold any container specific +/// constraints which would be upheld by `AssertClassContraints`. +/// +/// # Safety +/// Assuming `AssertClassContraints` succeeds, the value resulting from a +/// successful conversion must be safe to use for the following two steps: +/// * `RawPtr::from_raw_ptr` if the input was created from `RawPtr::to_raw_ptr`. +/// * `Container::from_raw` for the resulting type's container. +/// +/// Possible constraints include, but are not limited to: +/// * Any contained pointers must point to the same location after the +/// conversion. +/// * Assuming an input pointer is aligned, the matching result pointer must be +/// suitably aligned for the target type. +/// * All bit patterns of any converted input type must be valid for the target +/// type. +/// * For any converted mutable pointers, the reverse is true as well. +/// * Any length field must be converted from it's input type to occupy the same +/// number of bytes in it's output type. +pub unsafe trait TryCastRaw: Copy { + /// Perform the cast. + fn try_cast_raw(self) -> Result; +} + +unsafe impl TryCastRaw<*const U> for *const T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result<*const U, PodCastError> { + static_assert!(AssertSameSize(T, U)); + try_cast_ptr(self) + } +} +unsafe impl TryCastRaw<*const U> for *mut T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result<*const U, PodCastError> { + (self as *const T).try_cast_raw() + } +} +unsafe impl TryCastRaw<*mut U> for *mut T +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn try_cast_raw(self) -> Result<*mut U, PodCastError> { + static_assert!(AssertSameSize(T, U)); + try_cast_mut_ptr(self) + } +} + +unsafe impl TryCastRaw> for RawSlice +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + Ok(RawSlice { + data: try_cast_ptr(self.data)?, + len: try_cast_len::(self.len)?, + }) + } +} +unsafe impl TryCastRaw> for RawMutSlice +where + RawSlice: TryCastRaw>, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + RawSlice { data: self.data as *const T, len: self.len }.try_cast_raw() + } +} +unsafe impl TryCastRaw> for RawMutSlice +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + Ok(RawMutSlice { + data: try_cast_mut_ptr(self.data)?, + len: try_cast_len::(self.len)?, + }) + } +} + +unsafe impl TryCastRaw<*const U> for RawSlice +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result<*const U, PodCastError> { + static_assert!(AssertSizeMultipleOf(U, T)); + static_assert!(AssertNonZeroSize(T)); + if size_of::() / size_of::() == self.len { + Ok(try_cast_ptr(self.data)?) + } else { + Err(PodCastError::SizeMismatch) + } + } +} +unsafe impl TryCastRaw<*const U> for RawMutSlice +where + RawSlice: TryCastRaw<*const U>, +{ + fn try_cast_raw(self) -> Result<*const U, PodCastError> { + RawSlice { data: self.data as *const T, len: self.len }.try_cast_raw() + } +} +unsafe impl TryCastRaw<*mut U> for RawMutSlice +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn try_cast_raw(self) -> Result<*mut U, PodCastError> { + static_assert!(AssertSizeMultipleOf(U, T)); + static_assert!(AssertNonZeroSize(T)); + if size_of::() / size_of::() == self.len { + Ok(try_cast_mut_ptr(self.data)?) + } else { + Err(PodCastError::SizeMismatch) + } + } +} + +unsafe impl TryCastRaw> for *const T +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + static_assert!(AssertSizeMultipleOf(T, U)); + static_assert!(AssertNonZeroSize(U)); + Ok(RawSlice { + data: try_cast_ptr(self)?, + len: size_of::() / size_of::(), + }) + } +} +unsafe impl TryCastRaw> for *mut T +where + *const T: TryCastRaw>, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + (self as *const T).try_cast_raw() + } +} +unsafe impl TryCastRaw> for *mut T +where + T: NoUninit + AnyBitPattern, + U: NoUninit + AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + static_assert!(AssertSizeMultipleOf(T, U)); + static_assert!(AssertNonZeroSize(U)); + Ok(RawMutSlice { + data: try_cast_mut_ptr(self)?, + len: size_of::() / size_of::(), + }) + } +} + +unsafe impl TryCastRaw> for Option +where + T: TryCastRaw, + U: Copy, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + match self { + Some(x) => Ok(Some(x.try_cast_raw()?)), + None => Ok(None), + } + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl TryCastRaw> for RawCow +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + Ok(match self { + Self::Borrowed(x) => RawCow::Borrowed(x.try_cast_raw()?), + Self::Owned(x) => RawCow::Owned(x.reinterpret()), + }) + } +} +#[cfg(feature = "extern_crate_alloc")] +unsafe impl TryCastRaw> for RawCow<[T]> +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + match self { + Self::Borrowed(x) => Ok(RawCow::Borrowed(x.try_cast_raw()?)), + Self::Owned(x) if align_of::() == align_of::() => { + Ok(RawCow::Owned(x.try_cast_raw()?)) + } + Self::Owned(_) => Err(PodCastError::AlignmentMismatch), + } + } +} + +#[cfg(feature = "extern_crate_alloc")] +unsafe impl TryCastRaw> for RawVec +where + T: NoUninit, + U: AnyBitPattern, +{ + fn try_cast_raw(self) -> Result, PodCastError> { + Ok(RawVec { + slice: self.slice.try_cast_raw()?, + cap: try_cast_len::(self.cap)?, + }) + } +} + +/// Checks any constraints the container requires when casting between types. +pub trait AssertClassContraints { + const ASSERT: () = (); +} +impl AssertClassContraints for RefT {} +impl AssertClassContraints for PtrT {} +impl AssertClassContraints for NonNullT {} +impl AssertClassContraints for AtomicPtrT {} +impl AssertClassContraints for OptionT +where + C: AssertClassContraints, +{ + const ASSERT: () = C::ASSERT; +} +impl AssertClassContraints for PinT +where + C: AssertClassContraints, +{ + const ASSERT: () = C::ASSERT; +} +#[cfg(feature = "extern_crate_alloc")] +impl AssertClassContraints + for BoxT +{ + const ASSERT: () = static_assert!(AssertSameAlign(T, U)); +} +#[cfg(feature = "extern_crate_alloc")] +impl AssertClassContraints + for RcT +{ + const ASSERT: () = static_assert!(AssertSameAlign(T, U)); +} +#[cfg(feature = "extern_crate_alloc")] +impl AssertClassContraints for RcWeakT { + const ASSERT: () = static_assert!(AssertSameAlign(T, U)); +} +#[cfg(feature = "extern_crate_alloc")] +impl AssertClassContraints + for ArcT +{ + const ASSERT: () = static_assert!(AssertSameAlign(T, U)); +} +#[cfg(feature = "extern_crate_alloc")] +impl AssertClassContraints for ArcWeakT { + const ASSERT: () = static_assert!(AssertSameAlign(T, U)); +} +#[cfg(feature = "extern_crate_alloc")] +impl AssertClassContraints for VecT { + const ASSERT: () = static_assert!(AssertSameAlign(T, U)); +} +#[cfg(feature = "extern_crate_alloc")] +impl< + T: ?Sized + RawToOwned + ItemLayout, + U: ?Sized + RawToOwned + ItemLayout, + > AssertClassContraints for CowT +{ +} + +/// Safe byte-wise conversion between two values which contain some number of +/// values without allocation. This conversion should not fail for any reason. +/// +/// This supports the following conversions: +/// * `&[mut] T`/`&[mut] [T]` -> `&U`/`&[U]` +/// * `&mut T`/`&mut [T]` -> `&mut U`/`&mut [T]` +/// * `*[const|mut] T` -> `*const U` +/// * `*mut T` -> `*mut U` +/// * `NonNull`/`NonNull<[T]>` -> `NonNull`/`NonNull<[U]>` (slice version +/// requires the `non_null_slice_cast` feature) +/// * `AtomicPtr` -> `AtomicPtr` +/// * `Pin` -> `Pin` where `T` -> `U` is valid +/// * `Option` -> `Option` where `T` -> `U` is valid +/// +/// With the `extern_crate_alloc` feature the following are also supported: +/// `Box`/`Box<[T]>` -> `Box`/`Box` +/// `Rc`/`Rc<[T]>` -> `Rc`/`Rc` +/// `rc::Weak`/`rc::Weak<[T]>` -> `rc::Weak`/`rc::Weak` +/// `Arc`/`Arc<[T]>` -> `Arc`/`Arc` +/// `sync::Weak`/`sync::Weak<[T]>` -> `sync::Weak`/`sync::Weak` +/// `Vec` -> `Vec` +/// `Cow` -> `Cow` +/// `Cow<[T]>` -> `Cow` +/// +/// This requires the `unified_cast` feature to be enabled and a rust version +/// `>=1.57`. +pub trait ReinterpretInner<'a, T: 'a>: 'a + Sized { + /// Performs the conversion. + fn reinterpret_inner(self) -> T; +} +impl<'a, T: 'a, U: 'a> ReinterpretInner<'a, U> for T +where + T: Container<'a, Class = U::Class>, + U: Container<'a>, + T::Class: AssertClassContraints, + T::Raw: CastRaw, +{ + fn reinterpret_inner(self) -> U { + static_assert!(AssertClassContraints(T::Class, T::Item, U::Item)); + unsafe { U::from_raw(self.into_raw().cast_raw()) } + } +} + +/// Attempt at a safe byte-wise conversion between two values which contain some +/// number of values. This conversion may fail due runtime conditions such as +/// size and alignment errors. +/// +/// This supports the following conversions: +/// * `&[mut] T`/`&[mut] [T]` -> `&U`/`&[U]` +/// * `&mut T`/`&mut [T]` -> `&mut U`/`&mut [T]` +/// * `*[const|mut] T` -> `*const U` +/// * `*mut T` -> `*mut U` +/// * `NonNull`/`NonNull<[T]>` -> `NonNull`/`NonNull<[U]>` (slice version +/// requires the `non_null_slice_cast` feature) +/// * `AtomicPtr` -> `AtomicPtr` +/// * `Pin` -> `Pin` where `T` -> `U` is valid +/// * `Option` -> `Option` where `T` -> `U` is valid +/// +/// With the `extern_crate_alloc` feature the following are also supported: +/// `Box`/`Box<[T]>` -> `Box`/`Box` +/// `Rc`/`Rc<[T]>` -> `Rc`/`Rc` +/// `rc::Weak`/`rc::Weak<[T]>` -> `rc::Weak`/`rc::Weak` +/// `Arc`/`Arc<[T]>` -> `Arc`/`Arc` +/// `sync::Weak`/`sync::Weak<[T]>` -> `sync::Weak`/`sync::Weak` +/// `Vec` -> `Vec` +/// `Cow` -> `Cow` +/// `Cow<[T]>` -> `Cow` +/// +/// This requires the `unified_cast` feature to be enabled and a rust version +/// `>=1.57`. +pub trait TryReinterpretInner<'a, T: 'a>: 'a + Sized { + /// The type returned in the event of a conversion error. + type Error; + /// Perform the conversion. + fn try_reinterpret_inner(self) -> Result; +} +impl<'a, T: 'a, U: 'a> TryReinterpretInner<'a, U> for T +where + T: Container<'a, Class = U::Class>, + U: Container<'a>, + T::Class: AssertClassContraints, + T::Raw: TryCastRaw, +{ + type Error = >::Err; + fn try_reinterpret_inner(self) -> Result { + static_assert!(AssertNonMixedZeroSize(T::Item, U::Item)); + static_assert!(AssertClassContraints(T::Class, T::Item, U::Item)); + let raw = self.into_raw(); + match raw.try_cast_raw() { + Ok(raw) => Ok(unsafe { U::from_raw(raw) }), + Err(e) => Err( + >::cast_error_with_value(e, unsafe { + T::from_raw(raw) + }), + ), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 7ee1f90..ca01fc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,6 +88,10 @@ macro_rules! impl_unsafe_marker_for_simd { #[cfg(feature = "extern_crate_std")] extern crate std; +#[cfg(feature = "unified_cast")] +#[macro_use] +mod static_assert; + #[cfg(feature = "extern_crate_alloc")] extern crate alloc; #[cfg(feature = "extern_crate_alloc")] @@ -125,6 +129,13 @@ pub use offset_of::*; mod transparent; pub use transparent::*; +#[cfg(feature = "unified_cast")] +mod cast; +#[cfg(feature = "unified_cast")] +pub use cast::{ + Reinterpret, ReinterpretInner, TryReinterpret, TryReinterpretInner, +}; + #[cfg(feature = "derive")] pub use bytemuck_derive::{ AnyBitPattern, ByteEq, ByteHash, CheckedBitPattern, Contiguous, NoUninit, diff --git a/src/static_assert.rs b/src/static_assert.rs new file mode 100644 index 0000000..a181435 --- /dev/null +++ b/src/static_assert.rs @@ -0,0 +1,105 @@ +//! Various type assertions used to raise compile errors. These are implemented +//! as panics during the evaluation of associated constants which will be +//! converted to compile time errors when the evaluation is forced. +//! +//! These can be used through the helper macro. e.g. +//! ```rust,ignore +//! static_assert!(AssertSameSize(u32, i32)); +//! ``` + +use core::mem::{align_of, size_of}; + +pub trait ItemLayout { + const SIZE: usize; + const ALIGN: usize; +} +impl ItemLayout for T { + const SIZE: usize = size_of::(); + const ALIGN: usize = align_of::(); +} +impl ItemLayout for [T] { + const SIZE: usize = size_of::(); + const ALIGN: usize = align_of::(); +} +impl ItemLayout for str { + const SIZE: usize = 1; + const ALIGN: usize = 1; +} + +pub trait AssertNonMixedZeroSize: ItemLayout { + const ASSERT: () = { + if (Self::SIZE == 0) != (T::SIZE == 0) { + panic!( + "Attempt to cast between a zero-sized type and a non-zero-sized type" + ); + } + }; +} +impl AssertNonMixedZeroSize + for T +{ +} + +pub trait AssertNonZeroSize: ItemLayout { + const ASSERT: () = { + if Self::SIZE == 0 { + panic!("Attempt to cast a zero-sized type"); + } + }; +} +impl AssertNonZeroSize for T {} + +pub trait AssertSameSize: ItemLayout { + const ASSERT: () = { + if Self::SIZE != T::SIZE { + panic!("Attempt to cast between two types with different sizes"); + } + }; +} +impl AssertSameSize for T {} + +pub trait AssertMaxSize: ItemLayout { + const ASSERT: () = { + if Self::SIZE > T::SIZE { + panic!("Attempt to cast to a type of a smaller size"); + } + }; +} +impl AssertMaxSize for T {} + +pub trait AssertSizeMultipleOf: ItemLayout { + const ASSERT: () = { + if Self::SIZE != T::SIZE && Self::SIZE % T::SIZE != 0 { + panic!("Attempt to cast from a type which is not a multiple of the target's size"); + } + }; +} +impl AssertSizeMultipleOf + for T +{ +} + +pub trait AssertSameAlign: ItemLayout { + const ASSERT: () = { + if Self::ALIGN != T::ALIGN { + panic!("Attempt to cast between two types with different alignments"); + } + }; +} +impl AssertSameAlign for T {} + +pub trait AssertMinAlign: ItemLayout { + const ASSERT: () = { + if Self::ALIGN < T::ALIGN { + panic!("Attempt to cast to a type with a larger alignment"); + } + }; +} +impl AssertMinAlign for T {} + +macro_rules! static_assert { + ($assertion:ident($ty:ty $(, $($args:tt)*)?)) => {{ + #[allow(path_statements)] + { <$ty as $assertion<$($($args)*)?>>::ASSERT; }; + }} +} diff --git a/tests/unified_cast_tests.rs b/tests/unified_cast_tests.rs new file mode 100644 index 0000000..fe2c086 --- /dev/null +++ b/tests/unified_cast_tests.rs @@ -0,0 +1,369 @@ +#![cfg(feature = "unified_cast")] + +use bytemuck::{ + PodCastError::{AlignmentMismatch, OutputSliceWouldHaveSlop}, + Reinterpret, ReinterpretInner, TryReinterpret, TryReinterpretInner, +}; +use core::{ + convert::identity, + num::NonZeroI32, + pin::Pin, + ptr::{self, NonNull}, + sync::atomic::AtomicPtr, +}; + +macro_rules! test_assert { + ($target:expr, $init:expr) => { + assert_eq!($target, $init); + }; + ($target:expr, $_init:expr, $e:expr) => { + assert_eq!($target, $e); + }; +} + +// Test both `Reinterpret` and `TryReinterpret` +macro_rules! test_reinterpret { + ( + $init:expr => $ty:ty $(= $assert_val:expr)? + ) => {{ + test_assert!(identity::<$ty>($init.reinterpret()), $init $(, $assert_val)?); + test_assert!(identity::<$ty>($init.try_reinterpret().unwrap()), $init $(, $assert_val)?); + }}; +} + +// Test both `TryReinterpret` +macro_rules! test_try_reinterpret { + ( + $init:expr => $ty:ty $(= $assert_val:expr)? + ) => {{ + test_assert!(identity::>($init.try_reinterpret()), Some($init) $(, $assert_val)?); + }}; +} + +// Test both `ReinterpretInner` and `TryReinterpretInner` +macro_rules! test_reinterpret_inner { + ( + $init:expr => $ty:ty $(= $assert_val:expr)? + ) => {{ + test_assert!(identity::<$ty>($init.reinterpret_inner()), $init $(, $assert_val)?); + test_assert!(identity::<$ty>($init.try_reinterpret_inner().unwrap()), $init $(, $assert_val)?); + }}; +} + +// Test both `ReinterpretInner` and `TryReinterpretInner` +macro_rules! test_try_reinterpret_inner { + ( + $init:expr => $ty:ty $(= $assert_val:expr)? + ) => {{ + test_assert!(identity::>($init.try_reinterpret_inner()), $init $(, $assert_val)?); + }}; +} + +#[test] +fn reinterpret_self() { + test_reinterpret!(0u8 => u8); + test_reinterpret!(() => ()); + test_reinterpret!([0i32, 1i32, 2i32] => [i32; 3]); +} + +#[test] +fn reinterpret_same_align() { + test_reinterpret!(1u8 => i8 = 1i8); + test_reinterpret!( + [u32::MAX, 0, 1] => [i32; 3] + = [i32::from_ne_bytes(u32::MAX.to_ne_bytes()), 0, 1] + ); + test_reinterpret!(0u32 => Option = None); +} + +#[test] +fn reinterpret_lesser_align() { + test_reinterpret!(0u16 => [u8; 2] = [0u8; 2]); + test_reinterpret!([0u64; 2] => [u16; 8] = [0u16; 8]); +} + +#[test] +fn reinterpret_greater_align() { + test_reinterpret!([0u16; 2] => u32 = 0); +} + +#[test] +fn reinterpret_no_uninit() { + test_reinterpret!(true => u8 = 1); +} + +#[test] +fn try_reinterpret() { + test_try_reinterpret!(0u8 => bool = Some(false)); + test_try_reinterpret!(2u8 => bool = None); +} + +#[test] +fn reinterpret_inner_self() { + test_reinterpret_inner!(&() => &()); + test_reinterpret_inner!([0u32; 2].as_slice() => &[u32]); + test_reinterpret_inner!(&mut 1u32 => &mut u32); + test_reinterpret_inner!([1i32, 4i32].as_mut_slice() => &mut [i32]); + + let x = &[0u8; 2] as *const [u8; 2]; + test_reinterpret_inner!(x => *const [u8; 2]); + let x = &mut [0u8; 2] as *mut [u8; 2]; + test_reinterpret_inner!(x => *mut [u8; 2]); + let x = NonNull::from(&5u64); + test_reinterpret_inner!(x => NonNull); + + test_reinterpret_inner!(Some(&0u8) => Option<&u8>); + test_reinterpret_inner!(Some(Some([0u8; 4].as_slice())) => Option>); + test_reinterpret_inner!(Option::<&u8>::None => Option<&u8>); + + let x = Some(NonNull::from(&0i32)); + test_reinterpret_inner!(x => Option>); + + let _: AtomicPtr = AtomicPtr::new(&mut 1i8).reinterpret_inner(); + let _: AtomicPtr = + AtomicPtr::new(&mut 1i8).try_reinterpret_inner().unwrap(); + + test_reinterpret_inner!(Pin::new(&1u8) => Pin<&u8>); + test_reinterpret_inner!(Pin::new([0u16; 2].as_slice()) => Pin<&[u16]>); +} + +#[test] +fn reinterpret_inner_same_align() { + test_reinterpret_inner!(&50u8 => &i8 = &50); + test_reinterpret_inner!([0u32; 2].as_slice() => &[i32] = [0; 2]); + test_reinterpret_inner!( + &mut 1f32 => &mut u32 + = &mut u32::from_ne_bytes(1f32.to_ne_bytes()) + ); + test_reinterpret_inner!([1i32, 4i32].as_mut_slice() => &mut [u32] = [1u32, 4u32]); + + let x = &[0u8; 2] as *const [u8; 2]; + test_reinterpret_inner!(x => *const [i8; 2] = x.cast()); + let x = &mut [0u8; 2] as *mut [u8; 2]; + test_reinterpret_inner!(x => *mut [i8; 2] = x.cast()); + let x = NonNull::from(&5u64); + test_reinterpret_inner!(x => NonNull = x.cast()); + + test_reinterpret_inner!(Some(&0u8) => Option<&i8> = Some(&0i8)); + test_reinterpret_inner!( + Some(Some([127u8; 4].as_slice())) => Option> + = Some(Some([127i8; 4].as_slice())) + ); + test_reinterpret_inner!(Option::<&u8>::None => Option<&u8> = None); + + let x = Some(NonNull::from(&0i32)); + test_reinterpret_inner!(x => Option> = x.map(|x| x.cast())); + + let _: AtomicPtr = AtomicPtr::new(&mut 1i8).reinterpret_inner(); + let _: AtomicPtr = + AtomicPtr::new(&mut 1i8).try_reinterpret_inner().unwrap(); + + test_reinterpret_inner!(Pin::new(&1u8) => Pin<&i8> = Pin::new(&1i8)); + test_reinterpret_inner!( + Pin::new([0xFFFFu16; 2].as_slice()) => Pin<&[i16]> + = Pin::new([i16::from_ne_bytes(0xFFFFu16.to_ne_bytes()); 2].as_slice()) + ); +} + +#[test] +fn reinterpret_inner_lesser_align() { + test_reinterpret_inner!(&0xFF01u16 => &[u8; 2] = &0xFF01u16.to_ne_bytes()); + test_reinterpret_inner!([0u32; 2].as_slice() => &[[u16; 2]] = [[0; 2]; 2]); + test_reinterpret_inner!(&mut 1f32 => &mut [u8; 4] = &1f32.to_ne_bytes()); + test_reinterpret_inner!( + [1i32, 4i32].as_mut_slice() => &mut [[u8; 4]] + = &[1u32.to_ne_bytes(), 4u32.to_ne_bytes()] + ); + + let x = &[0u64; 2] as *const [u64; 2]; + test_reinterpret_inner!(x => *const [[u32; 2]; 2] = x.cast()); + let x = &mut 0u16 as *mut u16; + test_reinterpret_inner!(x => *mut [i8; 2] = x.cast()); + let x = NonNull::from(&5u64); + test_reinterpret_inner!(x => NonNull<[f32; 2]> = x.cast()); + + test_reinterpret_inner!(Some(&0u32) => Option<&[u8; 4]> = Some(&[0; 4])); + test_reinterpret_inner!( + Some(Some([127u16; 2].as_slice())) => Option> + = Some(Some([127u16.to_ne_bytes(); 2].as_slice())) + ); + test_reinterpret_inner!(Option::<&u16>::None => Option<&[u8; 2]> = None); + + let x = Some(NonNull::from(&0i32)); + test_reinterpret_inner!(x => Option> = x.map(|x| x.cast())); + + let _: AtomicPtr<[u8; 2]> = AtomicPtr::new(&mut 0u16).reinterpret_inner(); + let _: AtomicPtr<[u8; 2]> = + AtomicPtr::new(&mut 0u16).try_reinterpret_inner().unwrap(); + + test_reinterpret_inner!( + Pin::new(&1u16) => Pin<&[u8; 2]> + = Pin::new(&1u16.to_ne_bytes()) + ); + test_reinterpret_inner!( + Pin::new([0xFF00u16; 2].as_slice()) => Pin<&[[u8; 2]]> + = Pin::new([0xFF00u16.to_ne_bytes(); 2].as_slice()) + ); +} + +#[test] +fn reinterpret_inner_no_uninit() { + test_reinterpret_inner!(&true => &u8 = &1); + test_reinterpret_inner!([true, false].as_slice() => &[u8] = [1, 0]); + let x = &true as *const bool; + test_reinterpret_inner!(x => *const u8 = x.cast()); + test_reinterpret_inner!(Some(&true) => Option<&u8> = Some(&1)); + test_reinterpret_inner!(Some(Some(&true)) => Option> = Some(Some(&1))); + test_reinterpret_inner!(Pin::new(&true) => Pin<&u8> = Pin::new(&1)); +} + +#[test] +fn reinterpret_inner_unsize() { + test_reinterpret_inner!(&0u32 => &[u8] = [0; 4]); + test_reinterpret_inner!(&0u32 => &[u16] = [0; 2]); + test_reinterpret_inner!(&mut 0xFFEEDDCCu32 => &mut [u8] = 0xFFEEDDCCu32.to_ne_bytes()); + test_reinterpret_inner!(&mut 0u32 => &mut [u16] = [0; 2]); + #[cfg(feature = "non_null_slice_cast")] + { + let x = NonNull::from(&mut 0u32); + test_reinterpret_inner!( + x => NonNull<[u8]> + = NonNull::new(ptr::slice_from_raw_parts_mut(x.as_ptr().cast(), 4)).unwrap() + ); + test_reinterpret_inner!( + x => NonNull<[u16]> + = NonNull::new(ptr::slice_from_raw_parts_mut(x.as_ptr().cast(), 2)).unwrap() + ); + } +} + +#[test] +fn try_reinterpret_inner_misaligned() { + let x = [0u32, 0u32]; + let x: &[u16; 4] = (&x).reinterpret_inner(); + let x: &[u16; 2] = x[1..3].try_reinterpret_inner().unwrap(); + test_try_reinterpret_inner!(x => &u32 = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(x.as_slice() => &[u32] = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(Pin::new(x) => Pin<&u32> = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(x as *const [u16; 2] => *const u32 = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(Some(x) => Option<&u32> = Err(AlignmentMismatch)); + + let mut x = [0u32, 0u32]; + let x: &mut [u16; 4] = (&mut x).reinterpret_inner(); + let x: &mut [u16; 2] = (&mut x[1..3]).try_reinterpret_inner().unwrap(); + test_try_reinterpret_inner!(x => &mut u32 = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(x.as_mut_slice() => &mut [u32] = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(Pin::new(&mut *x) => Pin<&mut u32> = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(x as *mut [u16; 2] => *mut u32 = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(Some(&mut *x) => Option<&mut u32> = Err(AlignmentMismatch)); + test_try_reinterpret_inner!(NonNull::from(&mut *x) => NonNull = Err(AlignmentMismatch)); + let err: Result, _> = + AtomicPtr::new(&mut *x).try_reinterpret_inner(); + assert!(matches!(err, Err(AlignmentMismatch))); + #[cfg(feature = "non_null_slice_cast")] + { + test_try_reinterpret_inner!(NonNull::from(x.as_mut_slice()) => NonNull<[u32]> = Err(AlignmentMismatch)); + } +} + +#[test] +fn try_reinterpret_inner_greater_align() { + let x = 0u32; + let y: &[u16; 2] = (&x).reinterpret_inner(); + test_try_reinterpret_inner!(y => &u32 = Ok(&x)); + test_try_reinterpret_inner!(y.as_slice() => &[u32] = Ok([x].as_slice())); + test_try_reinterpret_inner!(Pin::new(y) => Pin<&u32> = Ok(Pin::new(&x))); + test_try_reinterpret_inner!(y as *const [u16; 2] => *const u32 = Ok(&x as *const u32)); + test_try_reinterpret_inner!(NonNull::from(y) => NonNull = Ok(NonNull::from(&x))); + test_try_reinterpret_inner!(Some(y) => Option<&u32> = Ok(Some(&x))); + + let mut x = 0u32; + let ptr = &mut x as *mut u32; + let y: &mut [u16; 2] = (&mut x).reinterpret_inner(); + test_try_reinterpret_inner!(y => &mut u32 = Ok(&mut 0)); + test_try_reinterpret_inner!(y.as_mut_slice() => &mut [u32] = Ok([0].as_mut_slice())); + test_try_reinterpret_inner!(Pin::new(&mut *y) => Pin<&mut u32> = Ok(Pin::new(&mut 0))); + test_try_reinterpret_inner!(y as *mut [u16; 2] => *mut u32 = Ok(ptr)); + test_try_reinterpret_inner!(Some(&mut *y) => Option<&mut u32> = Ok(Some(&mut 0))); + let _: AtomicPtr = + AtomicPtr::new(&mut *y).try_reinterpret_inner().unwrap(); + #[cfg(feature = "non_null_slice_cast")] + { + test_try_reinterpret_inner!( + NonNull::from(y.as_mut_slice()) => NonNull<[u32]> + = Ok(NonNull::new(ptr::slice_from_raw_parts_mut(ptr, 1)).unwrap()) + ); + } +} + +#[test] +fn try_reinterpret_change_element_size() { + test_try_reinterpret_inner!( + [0u32; 3].as_slice() => &[[u8; 3]] + = Ok([[0; 3]; 4].as_slice()) + ); + test_try_reinterpret_inner!( + [0u32; 1].as_slice() => &[[u8; 2]] + = Ok([[0; 2]; 2].as_slice()) + ); + test_try_reinterpret_inner!( + [0u32; 3].as_mut_slice() => &mut [[u8; 3]] + = Ok([[0; 3]; 4].as_mut_slice()) + ); + test_try_reinterpret_inner!( + [0u32; 1].as_mut_slice() => &mut [[u8; 2]] + = Ok([[0; 2]; 2].as_mut_slice()) + ); + test_try_reinterpret_inner!( + Some([0u32; 3].as_slice()) => Option<&[[u8; 3]]> + = Ok(Some([[0; 3]; 4].as_slice())) + ); + test_try_reinterpret_inner!( + Pin::new([0u32; 1].as_slice()) => Pin<&[[u8; 2]]> + = Ok(Pin::new([[0; 2]; 2].as_slice())) + ); + #[cfg(feature = "non_null_slice_cast")] + { + let mut x = [0u32; 3]; + test_try_reinterpret_inner!( + NonNull::from(x.as_mut_slice()) => NonNull<[[u8; 3]]> + = Ok(NonNull::new(ptr::slice_from_raw_parts_mut(x.as_mut_ptr().cast::<[u8; 3]>(), 4)).unwrap()) + ); + } +} + +#[test] +fn try_reinterpret_wrong_element_size() { + test_try_reinterpret_inner!( + [0u32; 3].as_slice() => &[[u8; 5]] + = Err(OutputSliceWouldHaveSlop) + ); + test_try_reinterpret_inner!( + [0u32; 1].as_slice() => &[[u32; 2]] + = Err(OutputSliceWouldHaveSlop) + ); + test_try_reinterpret_inner!( + [0u32; 3].as_mut_slice() => &mut [[u8; 5]] + = Err(OutputSliceWouldHaveSlop) + ); + test_try_reinterpret_inner!( + [0u32; 1].as_mut_slice() => &mut [[u32; 2]] + = Err(OutputSliceWouldHaveSlop) + ); + test_try_reinterpret_inner!( + Some([0u32; 3].as_slice()) => Option<&[[u8; 5]]> + = Err(OutputSliceWouldHaveSlop) + ); + test_try_reinterpret_inner!( + Pin::new([0u32; 1].as_slice()) => Pin<&[[u32; 2]]> + = Err(OutputSliceWouldHaveSlop) + ); + #[cfg(feature = "non_null_slice_cast")] + { + let mut x = [0u32; 3]; + test_try_reinterpret_inner!( + NonNull::from(x.as_mut_slice()) => NonNull<[[u8; 5]]> + = Err(OutputSliceWouldHaveSlop) + ); + } +}