Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable stubbing and function contracts for primitive types #3496

Merged
merged 6 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 137 additions & 130 deletions kani-compiler/src/kani_middle/resolve.rs

Large diffs are not rendered by default.

202 changes: 202 additions & 0 deletions kani-compiler/src/kani_middle/resolve/type_resolution.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//! This module contains code used for resolve type / trait names

use crate::kani_middle::resolve::{resolve_path, validate_kind, ResolveError};
use quote::ToTokens;
use rustc_hir::def::DefKind;
use rustc_middle::ty::TyCtxt;
use rustc_smir::rustc_internal;
use rustc_span::def_id::LocalDefId;
use stable_mir::mir::Mutability;
use stable_mir::ty::{FloatTy, IntTy, Region, RegionKind, RigidTy, Ty, UintTy};
use std::str::FromStr;
use strum_macros::{EnumString, IntoStaticStr};
use syn::{Expr, ExprLit, Lit, Type, TypePath};
use tracing::{debug, debug_span};

/// Attempts to resolve a type from a type expression.
pub fn resolve_ty<'tcx>(
tcx: TyCtxt<'tcx>,
current_module: LocalDefId,
typ: &syn::Type,
) -> Result<Ty, ResolveError<'tcx>> {
let _span = debug_span!("resolve_ty", ?typ).entered();
debug!(?typ, ?current_module, "resolve_ty");
let unsupported = |kind: &'static str| Err(ResolveError::UnsupportedPath { kind });
let invalid = |kind: &'static str| {
Err(ResolveError::InvalidPath {
msg: format!("Expected a type, but found {kind} `{}`", typ.to_token_stream()),
})
};
#[warn(non_exhaustive_omitted_patterns)]
match typ {
Type::Path(TypePath { qself, path }) => {
assert_eq!(*qself, None, "Unexpected qualified path");
if let Some(primitive) =
path.get_ident().and_then(|ident| PrimitiveIdent::from_str(&ident.to_string()).ok())
{
Ok(primitive.into())
} else {
let def_id = resolve_path(tcx, current_module, path)?;
validate_kind!(
tcx,
def_id,
"type",
DefKind::Struct | DefKind::Union | DefKind::Enum
)?;
Ok(rustc_internal::stable(tcx.type_of(def_id)).value)
}
}
Type::Array(array) => {
let elem_ty = resolve_ty(tcx, current_module, &array.elem)?;
let len = parse_len(&array.len).map_err(|msg| ResolveError::InvalidPath { msg })?;
Ty::try_new_array(elem_ty, len.try_into().unwrap()).map_err(|err| {
ResolveError::InvalidPath { msg: format!("Cannot instantiate array. {err}") }
})
}
Type::Paren(inner) => resolve_ty(tcx, current_module, &inner.elem),
Type::Ptr(ptr) => {
let elem_ty = resolve_ty(tcx, current_module, &ptr.elem)?;
let mutability =
if ptr.mutability.is_some() { Mutability::Mut } else { Mutability::Not };
Ok(Ty::new_ptr(elem_ty, mutability))
}
Type::Reference(reference) => {
let elem_ty = resolve_ty(tcx, current_module, &reference.elem)?;
let mutability =
if reference.mutability.is_some() { Mutability::Mut } else { Mutability::Not };
Ok(Ty::new_ref(Region { kind: RegionKind::ReErased }, elem_ty, mutability))
}
Type::Slice(slice) => {
let elem_ty = resolve_ty(tcx, current_module, &slice.elem)?;
Ok(Ty::from_rigid_kind(RigidTy::Slice(elem_ty)))
}
Type::Tuple(tuple) => {
let elems = tuple
.elems
.iter()
.map(|elem| resolve_ty(tcx, current_module, &elem))
.collect::<Result<Vec<_>, _>>()?;
Ok(Ty::new_tuple(&elems))
}
Type::Never(_) => Ok(Ty::from_rigid_kind(RigidTy::Never)),
Type::BareFn(_) => unsupported("bare function"),
Type::Macro(_) => invalid("macro"),
Type::Group(_) => invalid("group paths"),
Type::ImplTrait(_) => invalid("trait impl paths"),
Type::Infer(_) => invalid("inferred paths"),
Type::TraitObject(_) => invalid("trait object paths"),
Type::Verbatim(_) => unsupported("unknown paths"),
_ => {
unreachable!()
}
}
}

/// Enumeration of existing primitive types that are not parametric.
#[derive(Copy, Clone, Debug, Eq, PartialEq, IntoStaticStr, EnumString)]
#[strum(serialize_all = "lowercase")]
pub(super) enum PrimitiveIdent {
Bool,
Char,
F16,
F32,
F64,
F128,
I8,
I16,
I32,
I64,
I128,
Isize,
Str,
U8,
U16,
U32,
U64,
U128,
Usize,
}

/// Convert a primitive ident into a primitive `Ty`.
impl From<PrimitiveIdent> for Ty {
fn from(value: PrimitiveIdent) -> Self {
match value {
PrimitiveIdent::Bool => Ty::bool_ty(),
PrimitiveIdent::Char => Ty::from_rigid_kind(RigidTy::Char),
PrimitiveIdent::F16 => Ty::from_rigid_kind(RigidTy::Float(FloatTy::F16)),
PrimitiveIdent::F32 => Ty::from_rigid_kind(RigidTy::Float(FloatTy::F32)),
PrimitiveIdent::F64 => Ty::from_rigid_kind(RigidTy::Float(FloatTy::F64)),
PrimitiveIdent::F128 => Ty::from_rigid_kind(RigidTy::Float(FloatTy::F128)),
PrimitiveIdent::I8 => Ty::signed_ty(IntTy::I8),
PrimitiveIdent::I16 => Ty::signed_ty(IntTy::I16),
PrimitiveIdent::I32 => Ty::signed_ty(IntTy::I32),
PrimitiveIdent::I64 => Ty::signed_ty(IntTy::I64),
PrimitiveIdent::I128 => Ty::signed_ty(IntTy::I128),
PrimitiveIdent::Isize => Ty::signed_ty(IntTy::Isize),
PrimitiveIdent::Str => Ty::from_rigid_kind(RigidTy::Str),
PrimitiveIdent::U8 => Ty::unsigned_ty(UintTy::U8),
PrimitiveIdent::U16 => Ty::unsigned_ty(UintTy::U16),
PrimitiveIdent::U32 => Ty::unsigned_ty(UintTy::U32),
PrimitiveIdent::U64 => Ty::unsigned_ty(UintTy::U64),
PrimitiveIdent::U128 => Ty::unsigned_ty(UintTy::U128),
PrimitiveIdent::Usize => Ty::unsigned_ty(UintTy::Usize),
}
}
}

/// Checks if a Path segment represents a primitive.
///
/// Note that this function will return false for expressions that cannot be parsed as a type.
pub(super) fn is_primitive<T>(path: &T) -> bool
where
T: ToTokens,
{
let token = path.to_token_stream();
let Ok(typ) = syn::parse2(token) else { return false };
is_type_primitive(&typ)
}

/// Checks if a type is a primitive including composite ones.
pub(super) fn is_type_primitive(typ: &syn::Type) -> bool {
#[warn(non_exhaustive_omitted_patterns)]
match typ {
Type::Array(_)
| Type::Ptr(_)
| Type::Reference(_)
| Type::Slice(_)
| Type::Never(_)
| Type::Tuple(_) => true,
Type::Path(TypePath { qself: Some(qself), path }) => {
path.segments.is_empty() && is_type_primitive(&qself.ty)
}
Type::Path(TypePath { qself: None, path }) => path
.get_ident()
.map_or(false, |ident| PrimitiveIdent::from_str(&ident.to_string()).is_ok()),
Type::BareFn(_)
| Type::Group(_)
| Type::ImplTrait(_)
| Type::Infer(_)
| Type::Macro(_)
| Type::Paren(_)
| Type::TraitObject(_)
| Type::Verbatim(_) => false,
_ => {
unreachable!()
}
}
}

/// Parse the length of the array.
/// We currently only support a constant length.
fn parse_len(len: &Expr) -> Result<usize, String> {
if let Expr::Lit(ExprLit { lit: Lit::Int(lit), .. }) = len {
if matches!(lit.suffix(), "" | "usize")
&& let Ok(val) = usize::from_str(lit.base10_digits())
{
return Ok(val);
}
}
Err(format!("Expected a `usize` constant, but found `{}`", len.to_token_stream()))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//
// kani-flags: -Z stubbing

//! Kani supports stubbing of most primitive methods, however, some methods, such as len, may be
//! lowered to an Rvalue.

/// Check that we can stub slices methods
pub mod slices_check {
#[derive(kani::Arbitrary)]
pub struct MyStruct(u8, i32);

pub fn stub_len_is_10<T>(_: &[T]) -> usize {
10
}

// This fails since `<[T]>::len` is lowered to `Rvalue::Len`.
#[kani::proof]
#[kani::stub(<[MyStruct]>::len, stub_len_is_10)]
pub fn check_stub_len_is_10() {
let input: [MyStruct; 5] = kani::any();
let slice = kani::slice::any_slice_of_array(&input);
assert_eq!(slice.len(), 10);
}
}
18 changes: 18 additions & 0 deletions tests/kani/Stubbing/StubPrimitives/stub_char_methods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//
// kani-flags: -Z stubbing
//
//! This tests that we can correctly stub char functions.
/// Check that we can stub is_ascii from `char`.
pub fn stub_is_ascii_true(_: &char) -> bool {
true
}

/// Check stubbing by directly calling `str::is_ascii`
#[kani::proof]
#[kani::stub(char::is_ascii, stub_is_ascii_true)]
pub fn check_stub_is_ascii() {
let input: char = kani::any();
assert!(input.is_ascii());
}
39 changes: 39 additions & 0 deletions tests/kani/Stubbing/StubPrimitives/stub_int_methods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//
// kani-flags: -Z stubbing
//
//! This tests that we can correctly stub integer types functions.

/// Generate stub and harness for count_ones method on integers.
macro_rules! stub_count_ones {
($ty:ty, $harness:ident, $stub:ident) => {
// Stub that always returns 0.
pub fn $stub(_: $ty) -> u32 {
0
}

// Harness
#[kani::proof]
#[kani::stub($ty::count_ones, $stub)]
pub fn $harness() {
let input = kani::any();
let ones = <$ty>::count_ones(input);
assert_eq!(ones, 0);
}
};
}

stub_count_ones!(u8, u8_count_ones, stub_u8_count_ones);
stub_count_ones!(u16, u16_count_ones, stub_u16_count_ones);
stub_count_ones!(u32, u32_count_ones, stub_u32_count_ones);
stub_count_ones!(u64, u64_count_ones, stub_u64_count_ones);
stub_count_ones!(u128, u128_count_ones, stub_u128_count_ones);
stub_count_ones!(usize, usize_count_ones, stub_usize_count_ones);

stub_count_ones!(i8, i8_count_ones, stub_i8_count_ones);
stub_count_ones!(i16, i16_count_ones, stub_i16_count_ones);
stub_count_ones!(i32, i32_count_ones, stub_i32_count_ones);
stub_count_ones!(i64, i64_count_ones, stub_i64_count_ones);
stub_count_ones!(i128, i128_count_ones, stub_i128_count_ones);
stub_count_ones!(isize, isize_count_ones, stub_isize_count_ones);
38 changes: 38 additions & 0 deletions tests/kani/Stubbing/StubPrimitives/stub_ptr_methods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//
// kani-flags: -Z stubbing
//
//! This tests that we can correctly stub methods from raw pointer types.

pub fn stub_len_is_10<T>(_: *const [T]) -> usize {
10
}

pub fn stub_mut_len_is_0<T>(_: *mut [T]) -> usize {
0
}

#[kani::proof]
#[kani::stub(<*const [u8]>::len, stub_len_is_10)]
#[kani::stub(<*mut [u8]>::len, stub_mut_len_is_0)]
pub fn check_stub_len_raw_ptr() {
let mut input: [u8; 5] = kani::any();
let mut_ptr = &mut input as *mut [u8];
let ptr = &input as *const [u8];
assert_eq!(mut_ptr.len(), 0);
assert_eq!(ptr.len(), 10);
}

pub fn stub_is_always_null<T>(_: *const T) -> bool {
true
}

// Fix-me: Option doesn't seem to work without the fully qualified path.
#[kani::proof]
#[kani::stub(<*const std::option::Option>::is_null, stub_is_always_null)]
pub fn check_stub_is_null() {
let input: Option<char> = kani::any();
let ptr = &input as *const Option<char>;
assert!(unsafe { ptr.is_null() });
}
37 changes: 37 additions & 0 deletions tests/kani/Stubbing/StubPrimitives/stub_slice_methods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//
// kani-flags: -Z stubbing
//
//! This tests that we can correctly stub slices and string slices functions.

/// Check that we can stub str::is_ascii
pub mod str_check {
pub fn stub_is_ascii_false(_: &str) -> bool {
false
}

#[kani::proof]
#[kani::stub(str::is_ascii, stub_is_ascii_false)]
pub fn check_stub_is_ascii() {
let input = "is_ascii";
assert!(!input.is_ascii());
}
}

/// Check that we can stub slices
pub mod slices_check {
#[derive(kani::Arbitrary, Ord, PartialOrd, Copy, Clone, PartialEq, Eq)]
pub struct MyStruct(u8, i32);

pub fn stub_sort_noop<T>(_: &mut [T]) {}

#[kani::proof]
#[kani::stub(<[MyStruct]>::sort, stub_sort_noop)]
pub fn check_stub_sort_noop() {
let mut input: [MyStruct; 5] = kani::any();
let copy = input.clone();
input.sort();
assert_eq!(input, copy);
}
}
Loading
Loading