Skip to content

Commit

Permalink
Add KnownLayout::validate_cast
Browse files Browse the repository at this point in the history
TODO:
- Tests
- Do we need to think about `isize` overflow in order to prevent user
  code from being unsound (ie, generating invalid calls to
  `<*const _>::add`)? My current thinking is that we don't - it's
  already invalid to construct a memory region that overflows `isize`,
  and so it'd never be valid to call `validate_cast` in that case.

Co-authored-by: Jack Wrenn <jswrenn@amazon.com>
  • Loading branch information
joshlf and jswrenn committed Sep 9, 2023
1 parent 12e7fac commit 90bee71
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ optional = true
zerocopy-derive = { version = "=0.7.3", path = "zerocopy-derive" }

[dev-dependencies]
assert_matches = "1.5"
itertools = "0.11"
rand = { version = "0.8.5", features = ["small_rng"] }
rustversion = "1.0"
static_assertions = "1.1"
Expand Down
337 changes: 337 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,21 @@ pub struct DstLayout {
/// `size_of::<T>()`. For DSTs, the size represents the size of the type
/// when the trailing slice field contains 0 elements.
/// - For all types, the alignment represents the alignment of the type.
// TODO: If we end up replacing this with separate size and alignment to
// make Kani happy, file an issue to eventually adopt the stdlib's
// `Alignment` type trick.
_base_layout: Layout,
/// For sized types, `None`. For DSTs, the size of the element type of the
/// trailing slice.
_trailing_slice_elem_size: Option<usize>,
}

#[cfg_attr(test, derive(Copy, Clone, Debug))]
enum _CastType {
_Prefix,
_Suffix,
}

impl DstLayout {
/// Constructs a `DstLayout` which describes `T`.
///
Expand All @@ -251,6 +260,202 @@ impl DstLayout {
_trailing_slice_elem_size: Some(mem::size_of::<T>()),
}
}

/// Validates that a cast is sound from a layout perspective.
///
/// Validates that the size and alignment requirements of a type with the
/// layout described in `self` would not be violated by performing a
/// `cast_type` cast from a pointer with address `addr` which refers to a
/// memory region of size `bytes_len`.
///
/// If the cast is valid, `validate_cast` returns `(elems, split_at)`. If
/// `self` describes a dynamically-sized type, then `elems` is the maximum
/// number of trailing slice elements for which a cast would be valid (for
/// sized types, `elem` is meaningless and should be ignored). `split_at` is
/// the index at which to split the memory region in order for the prefix
/// (suffix) to contain the result of the cast, and in order for the
/// remaining suffix (prefix) to contain the leftover bytes.
///
/// There are three conditions under which a cast can fail:
/// - The smallest possible value for the type is larger than the provided
/// memory region
/// - A prefix cast is requested, and `addr` does not satisfy `self`'s
/// alignment requirement
/// - A suffix cast is requested, and `addr + bytes_len` does not satisfy
/// `self`'s alignment requirement (as a consequence, since the size of
/// the trailing slice element is a multiple of the alignment, no length
/// for the trailing slice will result in a starting address which is
/// properly aligned)
///
/// # Safety
///
/// The caller may assume that this implementation is correct, and may rely
/// on that assumption for the soundness of their code. In particular, the
/// caller may assume that:
/// - A pointer to the type (for dynamically sized types, this includes
/// `elems` as its pointer metadata) describes an object of size `size <=
/// bytes_len`
/// - If this is a prefix cast, `addr` satisfies `self`'s alignment
/// - If this is a suffix cast, `addr + bytes_len - size` satisfies `self`'s
/// alignment
///
/// # Panics
///
/// If `addr + bytes_len` overflows `usize`, `validate_cast` may panic, or
/// it may return incorrect results. No guarantees are made about when
/// `validate_cast` will panic. The caller should not rely on
/// `validate_cast` panicking in any particular condition, even if
/// `debug_assertions` are enabled.
const fn _validate_cast(
&self,
addr: usize,
bytes_len: usize,
cast_type: _CastType,
) -> Option<(usize, usize)> {
// `debug_assert!`, but with `#[allow(clippy::arithmetic_side_effects)]`.
macro_rules! __debug_assert {
($e:expr $(, $msg:expr)?) => {
debug_assert!({
#[allow(clippy::arithmetic_side_effects)]
let e = $e;
e
} $(, $msg)?);
};
}

// Note that, in practice, `elem_size` is always a compile-time
// constant. We do this check earlier than needed to ensure that we
// always panic as a result of bugs in the program (such as calling this
// function on an invalid type) instead of allowing this panic to be
// hidden if the cast would have failed anyway for runtime reasons (such
// as a too-small memory region).
//
// TODO(#67): Once our MSRV is 1.65, use let-else:
// https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html#let-else-statements
let elem_size = match self._trailing_slice_elem_size {
Some(elem_size) => match NonZeroUsize::new(elem_size) {
Some(elem_size) => Some(elem_size),
None => panic!("attempted to cast to slice type with zero-sized element"),
},
None => None,
};

// Precondition
__debug_assert!(addr.checked_add(bytes_len).is_some(), "`addr` + `bytes_len` > usize::MAX");

// We check alignment for `addr` (for prefix casts) or `addr +
// bytes_len` (for suffix casts). For a prefix cast, the correctness of
// this check is trivial - `addr` is the address the object will live
// at.
//
// For a suffix cast, we know that all valid sizes for the type are a
// multiple of the alignment. Thus, a validly-sized instance which lives
// at a validly-aligned address must also end at a validly-aligned
// address. Thus, if the end address for a suffix cast (`addr +
// bytes_len`) is not aligned, then no valid start address will be
// aligned either.
let offset = match cast_type {
_CastType::_Prefix => 0,
_CastType::_Suffix => bytes_len,
};

// Addition is guaranteed not to overflow because `offset <= bytes_len`,
// and `addr + bytes_len <= usize::MAX` is a precondition of this
// method. Modulus is guaranteed not to divide by 0 because `.align()`
// guarantees that its return value is non-zero.
#[allow(clippy::arithmetic_side_effects)]
if (addr + offset) % self._base_layout.align() != 0 {
return None;
}

let base_size = self._base_layout.size();

// LEMMA 0: max_slice_bytes + base_size == bytes_len
//
// LEMMA 1: base_size <= bytes_len:
// - If `base_size > bytes_len`, `bytes_len.checked_sub(base_size)`
// returns `None`, and we return.
//
// TODO(#67): Once our MSRV is 1.65, use let-else:
// https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html#let-else-statements
let max_slice_bytes = if let Some(max_byte_slice) = bytes_len.checked_sub(base_size) {
max_byte_slice
} else {
return None;
};

// Lemma 0
__debug_assert!(max_slice_bytes + base_size == bytes_len);

// Lemma 1
__debug_assert!(base_size <= bytes_len);

let (elems, self_bytes) = if let Some(elem_size) = elem_size {
// Guaranteed not to divide by 0 because `elem_size` is a
// `NonZeroUsize`.
#[allow(clippy::arithmetic_side_effects)]
let elems = max_slice_bytes / elem_size.get();

// NOTE: Another option for this step in the algorithm is to set
// `slice_bytes = elems * elem_size`. However, using multiplication
// causes Kani to choke. In practice, the compiler is likely to
// generate identical machine code in both cases. Note that this
// divide-then-mod approach is trivially optimizable into a single
// operation that computes both the quotient and the remainder.

// First line is guaranteed not to mod by 0 because `elem_size` is a
// `NonZeroUsize`. Second line is guaranteed not to underflow
// because `rem <= max_slice_bytes` thanks to the mod operation.
//
// LEMMA 2: slice_bytes <= max_slice_bytes
#[allow(clippy::arithmetic_side_effects)]
let rem = max_slice_bytes % elem_size.get();
#[allow(clippy::arithmetic_side_effects)]
let slice_bytes = max_slice_bytes - rem;

// Lemma 2
__debug_assert!(slice_bytes <= max_slice_bytes);

// Guaranteed not to overflow:
// - max_slice_bytes + base_size == bytes_len (lemma 0)
// - slice_bytes <= max_slice_bytes (lemma 2)
// - slice_bytes + base_size <= bytes_len (substitution) ------+
// - bytes_len <= usize::MAX (bytes_len: usize) |
// - slice_bytes + base_size <= usize::MAX (substitution) |
// |
// LEMMA 3: self_bytes <= bytes_len: |
// - slice_bytes + base_size <= bytes_len <--------------------------+ (reused for lemma)
// - slice_bytes <= bytes_len
#[allow(clippy::arithmetic_side_effects)]
let self_bytes = base_size + slice_bytes;

// Lemma 3
__debug_assert!(self_bytes <= bytes_len);

(elems, self_bytes)
} else {
(0, base_size)
};

// LEMMA 4: self_bytes <= bytes_len:
// - `if` branch returns `self_bytes`; lemma 3 guarantees `self_bytes <=
// bytes_len`
// - `else` branch returns `base_size`; lemma 1 guarantees `base_size <=
// bytes_len`

// Lemma 4
__debug_assert!(self_bytes <= bytes_len);

let split_at = match cast_type {
_CastType::_Prefix => self_bytes,
// Guaranteed not to underflow because `self_bytes <= bytes_len`
// (lemma 4).
#[allow(clippy::arithmetic_side_effects)]
_CastType::_Suffix => bytes_len - self_bytes,
};

Some((elems, split_at))
}
}

/// A trait which carries information about a type's layout that is used by the
Expand Down Expand Up @@ -2738,6 +2943,138 @@ mod tests {
}
}

// This test takes a long time when running under Miri, so we skip it in
// that case. This is acceptable because this is a logic test that doesn't
// attempt to expose UB.
#[test]
#[cfg_attr(miri, ignore)]
fn test_validate_cast() {
fn layout(
base_size: usize,
align: usize,
_trailing_slice_elem_size: Option<usize>,
) -> DstLayout {
DstLayout {
_base_layout: Layout::from_size_align(base_size, align).unwrap(),
_trailing_slice_elem_size,
}
}

/// This macro accepts arguments in the form of:
///
/// layout(_, _, _).validate_cast(_, _, _), Ok(Some((_, _)))
/// | | | | | | | |
/// base_size ----+ | | | | | | |
/// align -----------+ | | | | | |
/// trailing_size ------+ | | | | |
/// addr --------------------------------+ | | | |
/// bytes_len ------------------------------+ | | |
/// cast_type ---------------------------------+ | |
/// elems --------------------------------------------------+ |
/// split_at --------------------------------------------------+
///
/// Each argument can either be an iterator or a wildcard. Each
/// wildcarded variable is implicitly replaced by an iterator over a
/// representative sample of values for that variable. Each `test!`
/// invocation iterates over every combination of values provided by
/// each variable's iterator (ie, the cartesian product) and validates
/// that the results are expected.
///
/// The final argument uses the same syntax, but it has a different
/// meaning:
/// - If it is `Ok(pat)`, then the pattern `pat` is supplied to
/// `assert_matches!` to validate the computed result for each
/// combination of input values.
/// - If it is `Err(msg)`, then `test!` validates that the call to
/// `validate_cast` panics with the given panic message.
///
/// Note that the meta-variables that match these variables have the
/// `tt` type, and some valid expressions are not valid `tt`s (such as
/// `a..b`). In this case, wrap the expression in parentheses, and it
/// will become valid `tt`.
macro_rules! test {
(
layout($base_size:tt, $align:tt, $trailing_size:tt)
.validate_cast($addr:tt, $bytes_len:tt, $cast_type:tt), $expect:pat $(,)?
) => {
itertools::iproduct!(
test!(@generate_usize $base_size),
test!(@generate_align $align),
test!(@generate_opt_usize $trailing_size),
test!(@generate_usize $addr),
test!(@generate_usize $bytes_len),
test!(@generate_cast_type $cast_type)
).for_each(|(base_size, align, trailing_size, addr, bytes_len, cast_type)| {
let actual = std::panic::catch_unwind(|| {
layout(base_size, align, trailing_size)._validate_cast(addr, bytes_len, cast_type)
}).map_err(|d| {
*d.downcast::<&'static str>().expect("expected string panic message").as_ref()
});
assert_matches::assert_matches!(
actual, $expect,
"layout({base_size}, {align}, {trailing_size:?}).validate_cast({addr}, {bytes_len}, {cast_type:?})",
);
});
};
(@generate_usize _) => { 0..8 };
(@generate_align _) => { [1, 2, 4, 8, 16] };
(@generate_opt_usize _) => { [None].into_iter().chain((0..8).map(Some).into_iter()) };
(@generate_cast_type _) => { [_CastType::_Prefix, _CastType::_Suffix] };
(@generate_cast_type $variant:ident) => { [_CastType::$variant] };
// Some expressions need to be wrapped in parentheses in order to be
// valid `tt`s (required by the top match pattern). See the comment
// below for more details. This arm removes these parentheses to
// avoid generating an `unused_parens` warning.
(@$_:ident ($vals:expr)) => { $vals };
(@$_:ident $vals:expr) => { $vals };
}

const EVENS: [usize; 5] = [0, 2, 4, 6, 8];
const NZ_EVENS: [usize; 5] = [2, 4, 6, 8, 10];
const ODDS: [usize; 5] = [1, 3, 5, 7, 9];

// base_size is too big for the memory region.
test!(layout((1..8), _, ((1..8).map(Some))).validate_cast(_, [0], _), Ok(None));
test!(layout((2..8), _, ((1..8).map(Some))).validate_cast(_, [1], _), Ok(None));

// addr is unaligned for prefix cast
test!(layout(_, [2], [None]).validate_cast(ODDS, _, _Prefix), Ok(None));
test!(layout(_, [2], (NZ_EVENS.map(Some))).validate_cast(ODDS, _, _Prefix), Ok(None));

// addr is aligned, but end of buffer is unaligned for suffix cast
test!(layout(_, [2], [None]).validate_cast(EVENS, ODDS, _Suffix), Ok(None));
test!(layout(_, [2], (NZ_EVENS.map(Some))).validate_cast(EVENS, ODDS, _Suffix), Ok(None));

// TDOO: Success cases

// Unfortunately, these constants cannot easily be used in the
// implementation of `validate_cast`, since `panic!` consumes a string
// literal, not an expression.
mod msgs {
pub(super) const TRAILING: &str =
"attempted to cast to slice type with zero-sized element";
pub(super) const OVERFLOW: &str = "`addr` + `bytes_len` > usize::MAX";
}

// casts with ZST trailing element types are unsupported
test!(layout(_, _, [Some(0)]).validate_cast(_, _, _), Err(msgs::TRAILING),);

// addr + bytes_len must not overflow usize
test!(
layout(_, [1], (NZ_EVENS.map(Some))).validate_cast([usize::MAX], (1..100), _),
Err(msgs::OVERFLOW)
);
test!(layout(_, [1], [None]).validate_cast((1..100), [usize::MAX], _), Err(msgs::OVERFLOW));
test!(
layout([1], [1], [None]).validate_cast(
[usize::MAX / 2 + 1, usize::MAX],
[usize::MAX / 2 + 1, usize::MAX],
_
),
Err(msgs::OVERFLOW)
);
}

#[test]
fn test_known_layout() {
// Test that `$ty` and `ManuallyDrop<$ty>` have the expected layout.
Expand Down

0 comments on commit 90bee71

Please sign in to comment.