From e4a90e9aa366142e18d1ba01ea42136700da8f2f Mon Sep 17 00:00:00 2001 From: "Celina G. Val" Date: Mon, 25 Mar 2024 11:29:03 -0700 Subject: [PATCH] Implement validity checks (#3085) This is still incomplete, but hopefully it can be merged as an unstable feature. I'll publish an RFC shortly. This instruments the function body with assertion checks to see if users are generating invalid values. This covers: - Union access - Raw pointer dereference - Transmute value - Field assignment of struct with invalid values - Aggregate assignment Things not covered today should trigger ICE or a delayed verification failure due to unsupported feature. ## Design This change has two main design changes which are inside the new `kani_compiler::kani_middle::transform` module: 1- Instance body should now be retrieved from the `BodyTransformation` structure. This structure will run transformation passes on instance bodies (i.e.: monomorphic instances) and cache the result. 2- Create a new transformation pass that instruments the body of a function for every potential invalid value generation. 3- Create a body builder which contains all elements of a function body and mutable functions to modify them accordingly. Related to #2998 Fixes #301 --------- Co-authored-by: Zyad Hassan <88045115+zhassan-aws@users.noreply.github.com> --- kani-compiler/src/args.rs | 11 + .../codegen_cprover_gotoc/codegen/function.rs | 4 +- .../compiler_interface.rs | 25 +- .../codegen_cprover_gotoc/context/goto_ctx.rs | 5 + .../codegen_cprover_gotoc/overrides/hooks.rs | 26 +- kani-compiler/src/kani_middle/attributes.rs | 11 + kani-compiler/src/kani_middle/mod.rs | 18 +- kani-compiler/src/kani_middle/provide.rs | 4 +- kani-compiler/src/kani_middle/reachability.rs | 28 +- .../src/kani_middle/transform/body.rs | 275 ++++++ .../src/kani_middle/transform/check_values.rs | 924 ++++++++++++++++++ .../src/kani_middle/transform/mod.rs | 135 +++ kani-driver/src/call_single_file.rs | 5 + kani_metadata/src/unstable.rs | 3 + tests/kani/ValidValues/constants.rs | 40 + tests/kani/ValidValues/custom_niche.rs | 100 ++ tests/kani/ValidValues/non_null.rs | 26 + tests/kani/ValidValues/write_invalid.rs | 37 + 18 files changed, 1641 insertions(+), 36 deletions(-) create mode 100644 kani-compiler/src/kani_middle/transform/body.rs create mode 100644 kani-compiler/src/kani_middle/transform/check_values.rs create mode 100644 kani-compiler/src/kani_middle/transform/mod.rs create mode 100644 tests/kani/ValidValues/constants.rs create mode 100644 tests/kani/ValidValues/custom_niche.rs create mode 100644 tests/kani/ValidValues/non_null.rs create mode 100644 tests/kani/ValidValues/write_invalid.rs diff --git a/kani-compiler/src/args.rs b/kani-compiler/src/args.rs index 69a82b61e19d..fcc346bd745f 100644 --- a/kani-compiler/src/args.rs +++ b/kani-compiler/src/args.rs @@ -71,4 +71,15 @@ pub struct Arguments { #[clap(long)] /// A legacy flag that is now ignored. goto_c: bool, + /// Enable specific checks. + #[clap(long)] + pub ub_check: Vec, +} + +#[derive(Debug, Clone, Copy, AsRefStr, EnumString, VariantNames, PartialEq, Eq)] +#[strum(serialize_all = "snake_case")] +pub enum ExtraChecks { + /// Check that produced values are valid except for uninitialized values. + /// See https://github.com/model-checking/kani/issues/920. + Validity, } diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs index 7cf4b979f407..d55696bfdc87 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs @@ -60,7 +60,7 @@ impl<'tcx> GotocCtx<'tcx> { debug!("Double codegen of {:?}", old_sym); } else { assert!(old_sym.is_function()); - let body = instance.body().unwrap(); + let body = self.transformer.body(self.tcx, instance); self.set_current_fn(instance, &body); self.print_instance(instance, &body); self.codegen_function_prelude(&body); @@ -201,7 +201,7 @@ impl<'tcx> GotocCtx<'tcx> { pub fn declare_function(&mut self, instance: Instance) { debug!("declaring {}; {:?}", instance.name(), instance); - let body = instance.body().unwrap(); + let body = self.transformer.body(self.tcx, instance); self.set_current_fn(instance, &body); debug!(krate=?instance.def.krate(), is_std=self.current_fn().is_std(), "declare_function"); self.ensure(&self.symbol_name_stable(instance), |ctx, fname| { diff --git a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs index cad8bcbfb9ce..d3b4ae8473de 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs @@ -12,6 +12,7 @@ use crate::kani_middle::provide; use crate::kani_middle::reachability::{ collect_reachable_items, filter_const_crate_items, filter_crate_items, }; +use crate::kani_middle::transform::BodyTransformation; use crate::kani_middle::{check_reachable_items, dump_mir_items}; use crate::kani_queries::QueryDb; use cbmc::goto_program::Location; @@ -86,16 +87,18 @@ impl GotocCodegenBackend { symtab_goto: &Path, machine_model: &MachineModel, check_contract: Option, + mut transformer: BodyTransformation, ) -> (GotocCtx<'tcx>, Vec, Option) { let items = with_timer( - || collect_reachable_items(tcx, starting_items), + || collect_reachable_items(tcx, &mut transformer, starting_items), "codegen reachability analysis", ); dump_mir_items(tcx, &items, &symtab_goto.with_extension("kani.mir")); // Follow rustc naming convention (cx is abbrev for context). // https://rustc-dev-guide.rust-lang.org/conventions.html#naming-conventions - let mut gcx = GotocCtx::new(tcx, (*self.queries.lock().unwrap()).clone(), machine_model); + let mut gcx = + GotocCtx::new(tcx, (*self.queries.lock().unwrap()).clone(), machine_model, transformer); check_reachable_items(gcx.tcx, &gcx.queries, &items); let contract_info = with_timer( @@ -227,6 +230,7 @@ impl CodegenBackend for GotocCodegenBackend { // - None: Don't generate code. This is used to compile dependencies. let base_filename = tcx.output_filenames(()).output_path(OutputType::Object); let reachability = queries.args().reachability_analysis; + let mut transformer = BodyTransformation::new(&queries, tcx); let mut results = GotoCodegenResults::new(tcx, reachability); match reachability { ReachabilityType::Harnesses => { @@ -248,8 +252,9 @@ impl CodegenBackend for GotocCodegenBackend { model_path, &results.machine_model, contract_metadata, + transformer, ); - results.extend(gcx, items, None); + transformer = results.extend(gcx, items, None); if let Some(assigns_contract) = contract_info { self.queries.lock().unwrap().register_assigns_contract( canonical_mangled_name(harness).intern(), @@ -263,7 +268,7 @@ impl CodegenBackend for GotocCodegenBackend { // test closure that we want to execute // TODO: Refactor this code so we can guarantee that the pair (test_fn, test_desc) actually match. let mut descriptions = vec![]; - let harnesses = filter_const_crate_items(tcx, |_, item| { + let harnesses = filter_const_crate_items(tcx, &mut transformer, |_, item| { if is_test_harness_description(tcx, item.def) { descriptions.push(item.def); true @@ -282,6 +287,7 @@ impl CodegenBackend for GotocCodegenBackend { &model_path, &results.machine_model, Default::default(), + transformer, ); results.extend(gcx, items, None); @@ -319,9 +325,10 @@ impl CodegenBackend for GotocCodegenBackend { &model_path, &results.machine_model, Default::default(), + transformer, ); assert!(contract_info.is_none()); - results.extend(gcx, items, None); + let _ = results.extend(gcx, items, None); } } @@ -613,12 +620,18 @@ impl GotoCodegenResults { } } - fn extend(&mut self, gcx: GotocCtx, items: Vec, metadata: Option) { + fn extend( + &mut self, + gcx: GotocCtx, + items: Vec, + metadata: Option, + ) -> BodyTransformation { let mut items = items; self.harnesses.extend(metadata); self.concurrent_constructs.extend(gcx.concurrent_constructs); self.unsupported_constructs.extend(gcx.unsupported_constructs); self.items.append(&mut items); + gcx.transformer } /// Prints a report at the end of the compilation. diff --git a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs index 10ee6a876588..a25e502aa574 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs @@ -18,6 +18,7 @@ use super::vtable_ctx::VtableCtx; use crate::codegen_cprover_gotoc::overrides::{fn_hooks, GotocHooks}; use crate::codegen_cprover_gotoc::utils::full_crate_name; use crate::codegen_cprover_gotoc::UnsupportedConstructs; +use crate::kani_middle::transform::BodyTransformation; use crate::kani_queries::QueryDb; use cbmc::goto_program::{DatatypeComponent, Expr, Location, Stmt, Symbol, SymbolTable, Type}; use cbmc::utils::aggr_tag; @@ -70,6 +71,8 @@ pub struct GotocCtx<'tcx> { /// We collect them and print one warning at the end if not empty instead of printing one /// warning at each occurrence. pub concurrent_constructs: UnsupportedConstructs, + /// The body transformation agent. + pub transformer: BodyTransformation, } /// Constructor @@ -78,6 +81,7 @@ impl<'tcx> GotocCtx<'tcx> { tcx: TyCtxt<'tcx>, queries: QueryDb, machine_model: &MachineModel, + transformer: BodyTransformation, ) -> GotocCtx<'tcx> { let fhks = fn_hooks(); let symbol_table = SymbolTable::new(machine_model.clone()); @@ -99,6 +103,7 @@ impl<'tcx> GotocCtx<'tcx> { global_checks_count: 0, unsupported_constructs: FxHashMap::default(), concurrent_constructs: FxHashMap::default(), + transformer, } } } diff --git a/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs b/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs index 2fcbaa36fb44..18cc44b7b20d 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs @@ -10,6 +10,7 @@ use crate::codegen_cprover_gotoc::codegen::{bb_label, PropertyClass}; use crate::codegen_cprover_gotoc::GotocCtx; +use crate::kani_middle::attributes::matches_diagnostic as matches_function; use crate::unwrap_or_return_codegen_unimplemented_stmt; use cbmc::goto_program::{BuiltinFn, Expr, Location, Stmt, Type}; use rustc_middle::ty::TyCtxt; @@ -35,17 +36,6 @@ pub trait GotocHook { ) -> Stmt; } -fn matches_function(tcx: TyCtxt, instance: Instance, attr_name: &str) -> bool { - let attr_sym = rustc_span::symbol::Symbol::intern(attr_name); - if let Some(attr_id) = tcx.all_diagnostic_items(()).name_to_id.get(&attr_sym) { - if rustc_internal::internal(tcx, instance.def.def_id()) == *attr_id { - debug!("matched: {:?} {:?}", attr_id, attr_sym); - return true; - } - } - false -} - /// A hook for Kani's `cover` function (declared in `library/kani/src/lib.rs`). /// The function takes two arguments: a condition expression (bool) and a /// message (&'static str). @@ -57,7 +47,7 @@ fn matches_function(tcx: TyCtxt, instance: Instance, attr_name: &str) -> bool { struct Cover; impl GotocHook for Cover { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniCover") + matches_function(tcx, instance.def, "KaniCover") } fn handle( @@ -92,7 +82,7 @@ impl GotocHook for Cover { struct Assume; impl GotocHook for Assume { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniAssume") + matches_function(tcx, instance.def, "KaniAssume") } fn handle( @@ -116,7 +106,7 @@ impl GotocHook for Assume { struct Assert; impl GotocHook for Assert { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniAssert") + matches_function(tcx, instance.def, "KaniAssert") } fn handle( @@ -157,7 +147,7 @@ struct Nondet; impl GotocHook for Nondet { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniAnyRaw") + matches_function(tcx, instance.def, "KaniAnyRaw") } fn handle( @@ -201,7 +191,7 @@ impl GotocHook for Panic { || tcx.has_attr(def_id, rustc_span::sym::rustc_const_panic_str) || Some(def_id) == tcx.lang_items().panic_fmt() || Some(def_id) == tcx.lang_items().begin_panic_fn() - || matches_function(tcx, instance, "KaniPanic") + || matches_function(tcx, instance.def, "KaniPanic") } fn handle( @@ -221,7 +211,7 @@ impl GotocHook for Panic { struct IsReadOk; impl GotocHook for IsReadOk { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniIsReadOk") + matches_function(tcx, instance.def, "KaniIsReadOk") } fn handle( @@ -365,7 +355,7 @@ struct UntrackedDeref; impl GotocHook for UntrackedDeref { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance, "KaniUntrackedDeref") + matches_function(tcx, instance.def, "KaniUntrackedDeref") } fn handle( diff --git a/kani-compiler/src/kani_middle/attributes.rs b/kani-compiler/src/kani_middle/attributes.rs index df494d7daa3c..979877199dde 100644 --- a/kani-compiler/src/kani_middle/attributes.rs +++ b/kani-compiler/src/kani_middle/attributes.rs @@ -1037,3 +1037,14 @@ fn attr_kind(tcx: TyCtxt, attr: &Attribute) -> Option { _ => None, } } + +pub fn matches_diagnostic(tcx: TyCtxt, def: T, attr_name: &str) -> bool { + let attr_sym = rustc_span::symbol::Symbol::intern(attr_name); + if let Some(attr_id) = tcx.all_diagnostic_items(()).name_to_id.get(&attr_sym) { + if rustc_internal::internal(tcx, def.def_id()) == *attr_id { + debug!("matched: {:?} {:?}", attr_id, attr_sym); + return true; + } + } + false +} diff --git a/kani-compiler/src/kani_middle/mod.rs b/kani-compiler/src/kani_middle/mod.rs index e65befc9624e..4cd2d4a357d7 100644 --- a/kani-compiler/src/kani_middle/mod.rs +++ b/kani-compiler/src/kani_middle/mod.rs @@ -23,7 +23,7 @@ use rustc_target::abi::call::FnAbi; use rustc_target::abi::{HasDataLayout, TargetDataLayout}; use stable_mir::mir::mono::{Instance, InstanceKind, MonoItem}; use stable_mir::mir::pretty::pretty_ty; -use stable_mir::ty::{BoundVariableKind, RigidTy, Span as SpanStable, Ty, TyKind}; +use stable_mir::ty::{BoundVariableKind, FnDef, RigidTy, Span as SpanStable, Ty, TyKind}; use stable_mir::visitor::{Visitable, Visitor as TypeVisitor}; use stable_mir::{CrateDef, DefId}; use std::fs::File; @@ -41,6 +41,7 @@ pub mod provide; pub mod reachability; pub mod resolve; pub mod stubbing; +pub mod transform; /// Check that all crate items are supported and there's no misconfiguration. /// This method will exhaustively print any error / warning and it will abort at the end if any @@ -316,3 +317,18 @@ impl<'tcx> FnAbiOfHelpers<'tcx> for CompilerHelpers<'tcx> { } } } + +/// Find an instance of a function from the given crate that has been annotated with `diagnostic` +/// item. +fn find_fn_def(tcx: TyCtxt, diagnostic: &str) -> Option { + let attr_id = tcx + .all_diagnostic_items(()) + .name_to_id + .get(&rustc_span::symbol::Symbol::intern(diagnostic))?; + let TyKind::RigidTy(RigidTy::FnDef(def, _)) = + rustc_internal::stable(tcx.type_of(attr_id)).value.kind() + else { + return None; + }; + Some(def) +} diff --git a/kani-compiler/src/kani_middle/provide.rs b/kani-compiler/src/kani_middle/provide.rs index d5495acb67a7..d29635ecfd15 100644 --- a/kani-compiler/src/kani_middle/provide.rs +++ b/kani-compiler/src/kani_middle/provide.rs @@ -8,6 +8,7 @@ use crate::args::{Arguments, ReachabilityType}; use crate::kani_middle::intrinsics::ModelIntrinsics; use crate::kani_middle::reachability::{collect_reachable_items, filter_crate_items}; use crate::kani_middle::stubbing; +use crate::kani_middle::transform::BodyTransformation; use crate::kani_queries::QueryDb; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_middle::util::Providers; @@ -79,8 +80,9 @@ fn collect_and_partition_mono_items( rustc_smir::rustc_internal::run(tcx, || { let local_reachable = filter_crate_items(tcx, |_, _| true).into_iter().map(MonoItem::Fn).collect::>(); + // We do not actually need the value returned here. - collect_reachable_items(tcx, &local_reachable); + collect_reachable_items(tcx, &mut BodyTransformation::dummy(), &local_reachable); }) .unwrap(); (rustc_interface::DEFAULT_QUERY_PROVIDERS.collect_and_partition_mono_items)(tcx, key) diff --git a/kani-compiler/src/kani_middle/reachability.rs b/kani-compiler/src/kani_middle/reachability.rs index 2fa1bf057c1c..b46fbdccc016 100644 --- a/kani-compiler/src/kani_middle/reachability.rs +++ b/kani-compiler/src/kani_middle/reachability.rs @@ -37,12 +37,17 @@ use stable_mir::{CrateDef, ItemKind}; use crate::kani_middle::coercion; use crate::kani_middle::coercion::CoercionBase; use crate::kani_middle::stubbing::{get_stub, validate_instance}; +use crate::kani_middle::transform::BodyTransformation; /// Collect all reachable items starting from the given starting points. -pub fn collect_reachable_items(tcx: TyCtxt, starting_points: &[MonoItem]) -> Vec { +pub fn collect_reachable_items( + tcx: TyCtxt, + transformer: &mut BodyTransformation, + starting_points: &[MonoItem], +) -> Vec { // For each harness, collect items using the same collector. // I.e.: This will return any item that is reachable from one or more of the starting points. - let mut collector = MonoItemsCollector::new(tcx); + let mut collector = MonoItemsCollector::new(tcx, transformer); for item in starting_points { collector.collect(item.clone()); } @@ -92,7 +97,11 @@ where /// /// Probably only specifically useful with a predicate to find `TestDescAndFn` const declarations from /// tests and extract the closures from them. -pub fn filter_const_crate_items(tcx: TyCtxt, mut predicate: F) -> Vec +pub fn filter_const_crate_items( + tcx: TyCtxt, + transformer: &mut BodyTransformation, + mut predicate: F, +) -> Vec where F: FnMut(TyCtxt, Instance) -> bool, { @@ -103,7 +112,7 @@ where // Only collect monomorphic items. if let Ok(instance) = Instance::try_from(item) { if predicate(tcx, instance) { - let body = instance.body().unwrap(); + let body = transformer.body(tcx, instance); let mut collector = MonoItemsFnCollector { tcx, body: &body, @@ -118,9 +127,11 @@ where roots } -struct MonoItemsCollector<'tcx> { +struct MonoItemsCollector<'tcx, 'a> { /// The compiler context. tcx: TyCtxt<'tcx>, + /// The body transformation object used to retrieve a transformed body. + transformer: &'a mut BodyTransformation, /// Set of collected items used to avoid entering recursion loops. collected: FxHashSet, /// Items enqueued for visiting. @@ -129,14 +140,15 @@ struct MonoItemsCollector<'tcx> { call_graph: debug::CallGraph, } -impl<'tcx> MonoItemsCollector<'tcx> { - pub fn new(tcx: TyCtxt<'tcx>) -> Self { +impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { + pub fn new(tcx: TyCtxt<'tcx>, transformer: &'a mut BodyTransformation) -> Self { MonoItemsCollector { tcx, collected: FxHashSet::default(), queue: vec![], #[cfg(debug_assertions)] call_graph: debug::CallGraph::default(), + transformer, } } @@ -174,7 +186,7 @@ impl<'tcx> MonoItemsCollector<'tcx> { fn visit_fn(&mut self, instance: Instance) -> Vec { let _guard = debug_span!("visit_fn", function=?instance).entered(); if validate_instance(self.tcx, instance) { - let body = instance.body().unwrap(); + let body = self.transformer.body(self.tcx, instance); let mut collector = MonoItemsFnCollector { tcx: self.tcx, collected: FxHashSet::default(), diff --git a/kani-compiler/src/kani_middle/transform/body.rs b/kani-compiler/src/kani_middle/transform/body.rs new file mode 100644 index 000000000000..32ffb373d767 --- /dev/null +++ b/kani-compiler/src/kani_middle/transform/body.rs @@ -0,0 +1,275 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// +//! Utility functions that allow us to modify a function body. + +use crate::kani_middle::find_fn_def; +use rustc_middle::ty::TyCtxt; +use stable_mir::mir::mono::Instance; +use stable_mir::mir::{ + BasicBlock, BasicBlockIdx, BinOp, Body, CastKind, Constant, Local, LocalDecl, Mutability, + Operand, Place, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, UnwindAction, + VarDebugInfo, +}; +use stable_mir::ty::{Const, GenericArgs, Span, Ty, UintTy}; +use std::mem; + +/// This structure mimics a Body that can actually be modified. +pub struct MutableBody { + blocks: Vec, + + /// Declarations of locals within the function. + /// + /// The first local is the return value pointer, followed by `arg_count` + /// locals for the function arguments, followed by any user-declared + /// variables and temporaries. + locals: Vec, + + /// The number of arguments this function takes. + arg_count: usize, + + /// Debug information pertaining to user variables, including captures. + var_debug_info: Vec, + + /// Mark an argument (which must be a tuple) as getting passed as its individual components. + /// + /// This is used for the "rust-call" ABI such as closures. + spread_arg: Option, + + /// The span that covers the entire function body. + span: Span, +} + +impl MutableBody { + /// Get the basic blocks of this builder. + pub fn blocks(&self) -> &[BasicBlock] { + &self.blocks + } + + pub fn locals(&self) -> &[LocalDecl] { + &self.locals + } + + /// Create a mutable body from the original MIR body. + pub fn from(body: Body) -> Self { + MutableBody { + locals: body.locals().to_vec(), + arg_count: body.arg_locals().len(), + spread_arg: body.spread_arg(), + blocks: body.blocks, + var_debug_info: body.var_debug_info, + span: body.span, + } + } + + /// Create the new body consuming this mutable body. + pub fn into(self) -> Body { + Body::new( + self.blocks, + self.locals, + self.arg_count, + self.var_debug_info, + self.spread_arg, + self.span, + ) + } + + /// Add a new local to the body with the given attributes. + pub fn new_local(&mut self, ty: Ty, span: Span, mutability: Mutability) -> Local { + let decl = LocalDecl { ty, span, mutability }; + let local = self.locals.len(); + self.locals.push(decl); + local + } + + pub fn new_str_operand(&mut self, msg: &str, span: Span) -> Operand { + let literal = Const::from_str(msg); + Operand::Constant(Constant { span, user_ty: None, literal }) + } + + pub fn new_const_operand(&mut self, val: u128, uint_ty: UintTy, span: Span) -> Operand { + let literal = Const::try_from_uint(val, uint_ty).unwrap(); + Operand::Constant(Constant { span, user_ty: None, literal }) + } + + /// Create a raw pointer of `*mut type` and return a new local where that value is stored. + pub fn new_cast_ptr( + &mut self, + from: Operand, + pointee_ty: Ty, + mutability: Mutability, + before: &mut SourceInstruction, + ) -> Local { + assert!(from.ty(self.locals()).unwrap().kind().is_raw_ptr()); + let target_ty = Ty::new_ptr(pointee_ty, mutability); + let rvalue = Rvalue::Cast(CastKind::PtrToPtr, from, target_ty); + self.new_assignment(rvalue, before) + } + + /// Add a new assignment for the given binary operation. + /// + /// Return the local where the result is saved. + pub fn new_binary_op( + &mut self, + bin_op: BinOp, + lhs: Operand, + rhs: Operand, + before: &mut SourceInstruction, + ) -> Local { + let rvalue = Rvalue::BinaryOp(bin_op, lhs, rhs); + self.new_assignment(rvalue, before) + } + + /// Add a new assignment. + /// + /// Return local where the result is saved. + pub fn new_assignment(&mut self, rvalue: Rvalue, before: &mut SourceInstruction) -> Local { + let span = before.span(&self.blocks); + let ret_ty = rvalue.ty(&self.locals).unwrap(); + let result = self.new_local(ret_ty, span, Mutability::Not); + let stmt = Statement { kind: StatementKind::Assign(Place::from(result), rvalue), span }; + self.insert_stmt(stmt, before); + result + } + + /// Add a new assert to the basic block indicated by the given index. + /// + /// The new assertion will have the same span as the source instruction, and the basic block + /// will be split. The source instruction will be adjusted to point to the first instruction in + /// the new basic block. + pub fn add_check( + &mut self, + tcx: TyCtxt, + check_type: &CheckType, + source: &mut SourceInstruction, + value: Local, + msg: &str, + ) { + assert_eq!( + self.locals[value].ty, + Ty::bool_ty(), + "Expected boolean value as the assert input" + ); + let new_bb = self.blocks.len(); + let span = source.span(&self.blocks); + match check_type { + CheckType::Assert(assert_fn) => { + let assert_op = Operand::Copy(Place::from(self.new_local( + assert_fn.ty(), + span, + Mutability::Not, + ))); + let msg_op = self.new_str_operand(msg, span); + let kind = TerminatorKind::Call { + func: assert_op, + args: vec![Operand::Move(Place::from(value)), msg_op], + destination: Place { + local: self.new_local(Ty::new_tuple(&[]), span, Mutability::Not), + projection: vec![], + }, + target: Some(new_bb), + unwind: UnwindAction::Terminate, + }; + let terminator = Terminator { kind, span }; + self.split_bb(source, terminator); + } + CheckType::Panic(..) | CheckType::NoCore => { + tcx.sess + .dcx() + .struct_err("Failed to instrument the code. Cannot find `kani::assert`") + .with_note("Kani requires `kani` library in order to verify a crate.") + .emit(); + tcx.sess.dcx().abort_if_errors(); + unreachable!(); + } + } + } + + /// Split a basic block right before the source location and use the new terminator + /// in the basic block that was split. + /// + /// The source is updated to point to the same instruction which is now in the new basic block. + pub fn split_bb(&mut self, source: &mut SourceInstruction, new_term: Terminator) { + let new_bb_idx = self.blocks.len(); + let (idx, bb) = match source { + SourceInstruction::Statement { idx, bb } => { + let (orig_idx, orig_bb) = (*idx, *bb); + *idx = 0; + *bb = new_bb_idx; + (orig_idx, orig_bb) + } + SourceInstruction::Terminator { bb } => { + let orig_bb = *bb; + *bb = new_bb_idx; + (self.blocks[orig_bb].statements.len(), orig_bb) + } + }; + let old_term = mem::replace(&mut self.blocks[bb].terminator, new_term); + let bb_stmts = &mut self.blocks[bb].statements; + let remaining = bb_stmts.split_off(idx); + let new_bb = BasicBlock { statements: remaining, terminator: old_term }; + self.blocks.push(new_bb); + } + + /// Insert statement before the source instruction and update the source as needed. + pub fn insert_stmt(&mut self, new_stmt: Statement, before: &mut SourceInstruction) { + match before { + SourceInstruction::Statement { idx, bb } => { + self.blocks[*bb].statements.insert(*idx, new_stmt); + *idx += 1; + } + SourceInstruction::Terminator { bb } => { + // Append statements at the end of the basic block. + self.blocks[*bb].statements.push(new_stmt); + } + } + } +} + +#[derive(Clone, Debug)] +pub enum CheckType { + /// This is used by default when the `kani` crate is available. + Assert(Instance), + /// When the `kani` crate is not available, we have to model the check as an `if { panic!() }`. + Panic(Instance), + /// When building non-core crate, such as `rustc-std-workspace-core`, we cannot + /// instrument code, but we can still compile them. + NoCore, +} + +impl CheckType { + /// This will create the type of check that is available in the current crate. + /// + /// If `kani` crate is available, this will return [CheckType::Assert], and the instance will + /// point to `kani::assert`. Otherwise, we will collect the `core::panic_str` method and return + /// [CheckType::Panic]. + pub fn new(tcx: TyCtxt) -> CheckType { + if let Some(instance) = find_instance(tcx, "KaniAssert") { + CheckType::Assert(instance) + } else if let Some(instance) = find_instance(tcx, "panic_str") { + CheckType::Panic(instance) + } else { + CheckType::NoCore + } + } +} + +/// We store the index of an instruction to avoid borrow checker issues and unnecessary copies. +#[derive(Copy, Clone, Debug)] +pub enum SourceInstruction { + Statement { idx: usize, bb: BasicBlockIdx }, + Terminator { bb: BasicBlockIdx }, +} + +impl SourceInstruction { + pub fn span(&self, blocks: &[BasicBlock]) -> Span { + match *self { + SourceInstruction::Statement { idx, bb } => blocks[bb].statements[idx].span, + SourceInstruction::Terminator { bb } => blocks[bb].terminator.span, + } + } +} + +fn find_instance(tcx: TyCtxt, diagnostic: &str) -> Option { + Instance::resolve(find_fn_def(tcx, diagnostic)?, &GenericArgs(vec![])).ok() +} diff --git a/kani-compiler/src/kani_middle/transform/check_values.rs b/kani-compiler/src/kani_middle/transform/check_values.rs new file mode 100644 index 000000000000..aefc20b46a44 --- /dev/null +++ b/kani-compiler/src/kani_middle/transform/check_values.rs @@ -0,0 +1,924 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// +//! Implement a transformation pass that instrument the code to detect possible UB due to +//! the generation of an invalid value. +//! +//! This pass highly depend on Rust type layouts. For more details, see: +//! +//! +//! For that, we traverse the function body and look for unsafe operations that may generate +//! invalid values. For each operation found, we add checks to ensure the value is valid. +//! +//! Note: There is some redundancy in the checks that could be optimized. Example: +//! 1. We could merge the invalid values by the offset. +//! 2. We could avoid checking places that have been checked before. +use crate::args::ExtraChecks; +use crate::kani_middle::transform::body::{CheckType, MutableBody, SourceInstruction}; +use crate::kani_middle::transform::check_values::SourceOp::UnsupportedCheck; +use crate::kani_middle::transform::{TransformPass, TransformationType}; +use crate::kani_queries::QueryDb; +use rustc_middle::ty::TyCtxt; +use rustc_smir::rustc_internal; +use stable_mir::abi::{FieldsShape, Scalar, TagEncoding, ValueAbi, VariantsShape, WrappingRange}; +use stable_mir::mir::mono::{Instance, InstanceKind}; +use stable_mir::mir::visit::{Location, PlaceContext, PlaceRef}; +use stable_mir::mir::{ + AggregateKind, BasicBlockIdx, BinOp, Body, CastKind, Constant, FieldIdx, Local, LocalDecl, + MirVisitor, Mutability, NonDivergingIntrinsic, Operand, Place, ProjectionElem, Rvalue, + Statement, StatementKind, Terminator, TerminatorKind, +}; +use stable_mir::target::{MachineInfo, MachineSize}; +use stable_mir::ty::{AdtKind, Const, IndexedVal, RigidTy, Ty, TyKind, UintTy}; +use stable_mir::CrateDef; +use std::fmt::{Debug, Formatter}; +use strum_macros::AsRefStr; +use tracing::{debug, trace}; + +/// Instrument the code with checks for invalid values. +pub struct ValidValuePass { + check_type: CheckType, +} + +impl ValidValuePass { + pub fn new(tcx: TyCtxt) -> Self { + ValidValuePass { check_type: CheckType::new(tcx) } + } +} + +impl TransformPass for ValidValuePass { + fn transformation_type() -> TransformationType + where + Self: Sized, + { + TransformationType::Instrumentation + } + + fn is_enabled(&self, query_db: &QueryDb) -> bool + where + Self: Sized, + { + let args = query_db.args(); + args.ub_check.contains(&ExtraChecks::Validity) + } + + /// Transform the function body by inserting checks one-by-one. + /// For every unsafe dereference or a transmute operation, we check all values are valid. + fn transform(&self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body) { + trace!(function=?instance.name(), "transform"); + let mut new_body = MutableBody::from(body); + let orig_len = new_body.blocks().len(); + // Do not cache body.blocks().len() since it will change as we add new checks. + for bb_idx in 0..new_body.blocks().len() { + let Some(candidate) = + CheckValueVisitor::find_next(&new_body, bb_idx, bb_idx >= orig_len) + else { + continue; + }; + self.build_check(tcx, &mut new_body, candidate); + } + (orig_len != new_body.blocks().len(), new_body.into()) + } +} + +impl Debug for ValidValuePass { + /// Implement manually since MachineInfo doesn't currently derive Debug. + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + "ValidValuePass".fmt(f) + } +} + +impl ValidValuePass { + fn build_check(&self, tcx: TyCtxt, body: &mut MutableBody, instruction: UnsafeInstruction) { + debug!(?instruction, "build_check"); + let mut source = instruction.source; + for operation in instruction.operations { + match operation { + SourceOp::BytesValidity { ranges, target_ty, rvalue } => { + let value = body.new_assignment(rvalue, &mut source); + let rvalue_ptr = Rvalue::AddressOf(Mutability::Not, Place::from(value)); + for range in ranges { + let result = + self.build_limits(body, &range, rvalue_ptr.clone(), &mut source); + let msg = format!( + "Undefined Behavior: Invalid value of type `{}`", + // TODO: Fix pretty_ty + rustc_internal::internal(tcx, target_ty) + ); + body.add_check(tcx, &self.check_type, &mut source, result, &msg); + } + } + SourceOp::DerefValidity { pointee_ty, rvalue, ranges } => { + for range in ranges { + let result = self.build_limits(body, &range, rvalue.clone(), &mut source); + let msg = format!( + "Undefined Behavior: Invalid value of type `{}`", + // TODO: Fix pretty_ty + rustc_internal::internal(tcx, pointee_ty) + ); + body.add_check(tcx, &self.check_type, &mut source, result, &msg); + } + } + SourceOp::UnsupportedCheck { check, ty } => { + let reason = format!( + "Kani currently doesn't support checking validity of `{check}` for `{}` type", + rustc_internal::internal(tcx, ty) + ); + self.unsupported_check(tcx, body, &mut source, &reason); + } + } + } + } + + fn build_limits( + &self, + body: &mut MutableBody, + req: &ValidValueReq, + rvalue_ptr: Rvalue, + source: &mut SourceInstruction, + ) -> Local { + let span = source.span(body.blocks()); + debug!(?req, ?rvalue_ptr, ?span, "build_limits"); + let primitive_ty = uint_ty(req.size.bytes()); + let start_const = body.new_const_operand(req.valid_range.start, primitive_ty, span); + let end_const = body.new_const_operand(req.valid_range.end, primitive_ty, span); + let orig_ptr = if req.offset != 0 { + let start_ptr = move_local(body.new_assignment(rvalue_ptr, source)); + let byte_ptr = move_local(body.new_cast_ptr( + start_ptr, + Ty::unsigned_ty(UintTy::U8), + Mutability::Not, + source, + )); + let offset_const = body.new_const_operand(req.offset as _, UintTy::Usize, span); + let offset = move_local(body.new_assignment(Rvalue::Use(offset_const), source)); + move_local(body.new_binary_op(BinOp::Offset, byte_ptr, offset, source)) + } else { + move_local(body.new_assignment(rvalue_ptr, source)) + }; + let value_ptr = + body.new_cast_ptr(orig_ptr, Ty::unsigned_ty(primitive_ty), Mutability::Not, source); + let value = + Operand::Copy(Place { local: value_ptr, projection: vec![ProjectionElem::Deref] }); + let start_result = body.new_binary_op(BinOp::Ge, value.clone(), start_const, source); + let end_result = body.new_binary_op(BinOp::Le, value, end_const, source); + if req.valid_range.wraps_around() { + // valid >= start || valid <= end + body.new_binary_op( + BinOp::BitOr, + move_local(start_result), + move_local(end_result), + source, + ) + } else { + // valid >= start && valid <= end + body.new_binary_op( + BinOp::BitAnd, + move_local(start_result), + move_local(end_result), + source, + ) + } + } + + fn unsupported_check( + &self, + tcx: TyCtxt, + body: &mut MutableBody, + source: &mut SourceInstruction, + reason: &str, + ) { + let span = source.span(body.blocks()); + let rvalue = Rvalue::Use(Operand::Constant(Constant { + literal: Const::from_bool(false), + span, + user_ty: None, + })); + let result = body.new_assignment(rvalue, source); + body.add_check(tcx, &self.check_type, source, result, reason); + } +} + +fn move_local(local: Local) -> Operand { + Operand::Move(Place::from(local)) +} + +fn uint_ty(bytes: usize) -> UintTy { + match bytes { + 1 => UintTy::U8, + 2 => UintTy::U16, + 4 => UintTy::U32, + 8 => UintTy::U64, + 16 => UintTy::U128, + _ => unreachable!("Unexpected size: {bytes}"), + } +} + +/// Represent a requirement for the value stored in the given offset. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct ValidValueReq { + /// Offset in bytes. + offset: usize, + /// Size of this requirement. + size: MachineSize, + /// The range restriction is represented by a Scalar. + valid_range: WrappingRange, +} + +// TODO: Optimize checks by merging requirements whenever possible. +// There are a few cases that would need to be cover: +// 1- Ranges intersection is the same as one of the ranges (or both). +// 2- Ranges intersection is a new valid range. +// 3- Ranges intersection is a combination of two new ranges. +// 4- Intersection is empty. +impl ValidValueReq { + /// Only a type with `ValueAbi::Scalar` and `ValueAbi::ScalarPair` can be directly assigned an + /// invalid value directly. + /// + /// It's not possible to define a `rustc_layout_scalar_valid_range_*` to any other structure. + /// Note that this annotation only applies to the first scalar in the layout. + pub fn try_from_ty(machine_info: &MachineInfo, ty: Ty) -> Option { + let shape = ty.layout().unwrap().shape(); + match shape.abi { + ValueAbi::Scalar(Scalar::Initialized { value, valid_range }) + | ValueAbi::ScalarPair(Scalar::Initialized { value, valid_range }, _) => { + Some(ValidValueReq { offset: 0, size: value.size(machine_info), valid_range }) + } + ValueAbi::Scalar(_) + | ValueAbi::ScalarPair(_, _) + | ValueAbi::Uninhabited + | ValueAbi::Vector { .. } + | ValueAbi::Aggregate { .. } => None, + } + } + + /// Check if range is full. + pub fn is_full(&self) -> bool { + self.valid_range.is_full(self.size).unwrap() + } + + /// Check if this range contains `other` range. + /// + /// I.e., `scalar_2` ⊆ `scalar_1` + pub fn contains(&self, other: &ValidValueReq) -> bool { + assert_eq!(self.size, other.size); + match (self.valid_range.wraps_around(), other.valid_range.wraps_around()) { + (true, true) | (false, false) => { + self.valid_range.start <= other.valid_range.start + && self.valid_range.end >= other.valid_range.end + } + (true, false) => { + self.valid_range.start <= other.valid_range.start + || self.valid_range.end >= other.valid_range.end + } + (false, true) => self.is_full(), + } + } +} + +#[derive(AsRefStr, Clone, Debug)] +enum SourceOp { + /// Validity checks are done on a byte level when the Rvalue can generate invalid value. + /// + /// This variant tracks a location that is valid for its current type, but it may not be + /// valid for the given location in target type. This happens for: + /// - Transmute + /// - Field assignment + /// - Aggregate assignment + /// - Union Access + /// + /// Each range is a pair of offset and scalar that represents the valid values. + /// Note that the same offset may have multiple ranges that may require being joined. + BytesValidity { target_ty: Ty, rvalue: Rvalue, ranges: Vec }, + + /// Similar to BytesValidity, but it stores any dereference that may be unsafe. + /// + /// This can happen for: + /// - Raw pointer dereference + DerefValidity { pointee_ty: Ty, rvalue: Rvalue, ranges: Vec }, + + /// Represents a range check Kani currently does not support. + /// + /// This will translate into an assertion failure with an unsupported message. + /// There are many corner cases with the usage of #[rustc_layout_scalar_valid_range_*] + /// attribute. Such as valid ranges that do not intersect or enumeration with variants + /// with niche. + /// + /// Supporting all cases require significant work, and it is unlikely to exist in real world + /// code. To be on the sound side, we just emit an unsupported check, and users will need to + /// disable the check in person, and create a feature request for their case. + /// + /// TODO: Consider replacing the assertion(false) by an unsupported operation that emits a + /// compilation warning. + UnsupportedCheck { check: String, ty: Ty }, +} + +/// The unsafe instructions that may generate invalid values. +/// We need to instrument all operations to ensure the instruction is safe. +#[derive(Clone, Debug)] +struct UnsafeInstruction { + /// The instruction that depends on the potentially invalid value. + source: SourceInstruction, + /// The unsafe operations that may cause an invalid value in this instruction. + operations: Vec, +} + +/// Extract any source that may potentially trigger UB due to the generation of an invalid value. +/// +/// Generating an invalid value requires an unsafe operation, however, in MIR, it +/// may just be represented as a regular assignment. +/// +/// Thus, we have to instrument every assignment to an object that has niche and that the source +/// is an object of a different source, e.g.: +/// - Aggregate assignment +/// - Transmute +/// - MemCopy +/// - Cast +struct CheckValueVisitor<'a> { + locals: &'a [LocalDecl], + /// Whether we should skip the next instruction, since it might've been instrumented already. + /// When we instrument an instruction, we partition the basic block, and the instruction that + /// may trigger UB becomes the first instruction of the basic block, which we need to skip + /// later. + skip_next: bool, + /// The instruction being visited at a given point. + current: SourceInstruction, + /// The target instruction that should be verified. + pub target: Option, + /// The basic block being visited. + bb: BasicBlockIdx, + /// Machine information needed to calculate Niche. + machine: MachineInfo, +} + +impl<'a> CheckValueVisitor<'a> { + fn find_next( + body: &'a MutableBody, + bb: BasicBlockIdx, + skip_first: bool, + ) -> Option { + let mut visitor = CheckValueVisitor { + locals: body.locals(), + skip_next: skip_first, + current: SourceInstruction::Statement { idx: 0, bb }, + target: None, + bb, + machine: MachineInfo::target(), + }; + visitor.visit_basic_block(&body.blocks()[bb]); + visitor.target + } + + fn push_target(&mut self, op: SourceOp) { + let target = self + .target + .get_or_insert_with(|| UnsafeInstruction { source: self.current, operations: vec![] }); + target.operations.push(op); + } +} + +impl<'a> MirVisitor for CheckValueVisitor<'a> { + fn visit_statement(&mut self, stmt: &Statement, location: Location) { + if self.skip_next { + self.skip_next = false; + } else if self.target.is_none() { + // Leave it as an exhaustive match to be notified when a new kind is added. + match &stmt.kind { + StatementKind::Intrinsic(NonDivergingIntrinsic::CopyNonOverlapping(_)) => { + // Source and destination have the same type, so no invalid value cannot be + // generated. + } + StatementKind::Assign(place, rvalue) => { + // First check rvalue. + self.super_statement(stmt, location); + // Then check the destination place. + let ranges = assignment_check_points( + &self.machine, + self.locals, + place, + rvalue.ty(self.locals).unwrap(), + ); + if !ranges.is_empty() { + self.push_target(SourceOp::BytesValidity { + target_ty: self.locals[place.local].ty, + rvalue: rvalue.clone(), + ranges, + }); + } + } + StatementKind::FakeRead(_, _) + | StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(_) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Retag(_, _) + | StatementKind::PlaceMention(_) + | StatementKind::AscribeUserType { .. } + | StatementKind::Coverage(_) + | StatementKind::ConstEvalCounter + | StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(_)) + | StatementKind::Nop => self.super_statement(stmt, location), + } + } + + let SourceInstruction::Statement { idx, bb } = self.current else { unreachable!() }; + self.current = SourceInstruction::Statement { idx: idx + 1, bb }; + } + fn visit_terminator(&mut self, term: &Terminator, location: Location) { + if !(self.skip_next || self.target.is_some()) { + self.current = SourceInstruction::Terminator { bb: self.bb }; + // Leave it as an exhaustive match to be notified when a new kind is added. + match &term.kind { + TerminatorKind::Call { func, args, .. } => { + // Note: For transmute, both Src and Dst must be valid type. + // In this case, we need to save the Dst, and invoke super_terminator. + self.super_terminator(term, location); + let instance = expect_instance(self.locals, func); + if instance.kind == InstanceKind::Intrinsic { + match instance.intrinsic_name().unwrap().as_str() { + "write_bytes" => { + // The write bytes intrinsic may trigger UB in safe code. + // pub unsafe fn write_bytes(dst: *mut T, val: u8, count: usize) + // + // We don't support this operation yet. + let TyKind::RigidTy(RigidTy::RawPtr(target_ty, Mutability::Mut)) = + args[0].ty(self.locals).unwrap().kind() + else { + unreachable!() + }; + let validity = ty_validity_per_offset(&self.machine, target_ty, 0); + match validity { + Ok(ranges) if ranges.is_empty() => {} + _ => self.push_target(SourceOp::UnsupportedCheck { + check: "write_bytes".to_string(), + ty: target_ty, + }), + } + } + "transmute" | "transmute_copy" => { + unreachable!("Should've been lowered") + } + _ => {} + } + } + } + TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::InlineAsm { .. } => self.super_terminator(term, location), + } + } + } + + fn visit_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { + for (idx, elem) in place.projection.iter().enumerate() { + let place_ref = PlaceRef { local: place.local, projection: &place.projection[..idx] }; + match elem { + ProjectionElem::Deref => { + let ptr_ty = place_ref.ty(self.locals).unwrap(); + if ptr_ty.kind().is_raw_ptr() { + let target_ty = elem.ty(ptr_ty).unwrap(); + let validity = ty_validity_per_offset(&self.machine, target_ty, 0); + match validity { + Ok(ranges) if !ranges.is_empty() => { + self.push_target(SourceOp::DerefValidity { + pointee_ty: target_ty, + rvalue: Rvalue::Use( + Operand::Copy(Place { + local: place_ref.local, + projection: place_ref.projection.to_vec(), + }) + .clone(), + ), + ranges, + }) + } + Err(_msg) => self.push_target(SourceOp::UnsupportedCheck { + check: "raw pointer dereference".to_string(), + ty: target_ty, + }), + _ => {} + } + } + } + ProjectionElem::Field(idx, target_ty) => { + if target_ty.kind().is_union() + && (!ptx.is_mutating() || place.projection.len() > idx + 1) + { + let validity = ty_validity_per_offset(&self.machine, *target_ty, 0); + match validity { + Ok(ranges) if !ranges.is_empty() => { + self.push_target(SourceOp::BytesValidity { + target_ty: *target_ty, + rvalue: Rvalue::Use(Operand::Copy(Place { + local: place_ref.local, + projection: place_ref.projection.to_vec(), + })), + ranges, + }) + } + Err(_msg) => self.push_target(SourceOp::UnsupportedCheck { + check: "union access".to_string(), + ty: *target_ty, + }), + _ => {} + } + } + } + ProjectionElem::Downcast(_) => {} + ProjectionElem::OpaqueCast(_) => {} + ProjectionElem::Subtype(_) => {} + ProjectionElem::Index(_) + | ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subslice { .. } => { /* safe */ } + } + } + self.super_place(place, ptx, location) + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue, location: Location) { + match rvalue { + Rvalue::Cast(kind, op, dest_ty) => match kind { + CastKind::PtrToPtr => { + // For mutable raw pointer, if the type we are casting to is less restrictive + // than the original type, writing to the pointer could generate UB if the + // value is ever read again using the original pointer. + let TyKind::RigidTy(RigidTy::RawPtr(dest_pointee_ty, Mutability::Mut)) = + dest_ty.kind() + else { + // We only care about *mut T as *mut U + return; + }; + let src_ty = op.ty(self.locals).unwrap(); + debug!(?src_ty, ?dest_ty, "visit_rvalue mutcast"); + let TyKind::RigidTy(RigidTy::RawPtr(src_pointee_ty, _)) = src_ty.kind() else { + unreachable!() + }; + if let Ok(src_validity) = + ty_validity_per_offset(&self.machine, src_pointee_ty, 0) + { + if !src_validity.is_empty() { + if let Ok(dest_validity) = + ty_validity_per_offset(&self.machine, dest_pointee_ty, 0) + { + if dest_validity != src_validity { + self.push_target(SourceOp::UnsupportedCheck { + check: "mutable cast".to_string(), + ty: src_ty, + }) + } + } else { + self.push_target(SourceOp::UnsupportedCheck { + check: "mutable cast".to_string(), + ty: *dest_ty, + }) + } + } + } else { + self.push_target(SourceOp::UnsupportedCheck { + check: "mutable cast".to_string(), + ty: src_ty, + }) + } + } + CastKind::Transmute => { + debug!(?dest_ty, "transmute"); + // For transmute, we care about the destination type only. + // This could be optimized to only add a check if the requirements of the + // destination type are stricter than the source. + if let Ok(dest_validity) = ty_validity_per_offset(&self.machine, *dest_ty, 0) { + trace!(?dest_validity, "transmute"); + if !dest_validity.is_empty() { + self.push_target(SourceOp::BytesValidity { + target_ty: *dest_ty, + rvalue: rvalue.clone(), + ranges: dest_validity, + }) + } + } else { + self.push_target(SourceOp::UnsupportedCheck { + check: "transmute".to_string(), + ty: *dest_ty, + }) + } + } + CastKind::DynStar => self.push_target(UnsupportedCheck { + check: "Dyn*".to_string(), + ty: (rvalue.ty(self.locals).unwrap()), + }), + CastKind::PointerExposeAddress + | CastKind::PointerFromExposedAddress + | CastKind::PointerCoercion(_) + | CastKind::IntToInt + | CastKind::FloatToInt + | CastKind::FloatToFloat + | CastKind::IntToFloat + | CastKind::FnPtrToPtr => {} + }, + Rvalue::ShallowInitBox(_, _) => { + // The contents of the box is considered uninitialized. + // This should already be covered by the Assign detection. + } + Rvalue::Aggregate(kind, operands) => match kind { + // If the aggregated structure has invalid value, this could generate invalid value. + // But only if the operands don't have the exact same restrictions. + // This happens today with the usage of `rustc_layout_scalar_valid_range_*` + // attributes. + // In this case, only the value of the first member in memory can be restricted, + // thus, we only need to check the operand used to assign to the first in memory + // field. + AggregateKind::Adt(def, _variant, args, _, _) => { + if def.kind() == AdtKind::Struct { + let dest_ty = Ty::from_rigid_kind(RigidTy::Adt(*def, args.clone())); + if let Some(req) = ValidValueReq::try_from_ty(&self.machine, dest_ty) + && !req.is_full() + { + let dest_layout = dest_ty.layout().unwrap().shape(); + let first_op = + first_aggregate_operand(dest_ty, &dest_layout.fields, operands); + let first_ty = first_op.ty(self.locals).unwrap(); + // Rvalue must have same Abi layout except for range. + if !req.contains( + &ValidValueReq::try_from_ty(&self.machine, first_ty).unwrap(), + ) { + self.push_target(SourceOp::BytesValidity { + target_ty: dest_ty, + rvalue: Rvalue::Use(first_op), + ranges: vec![req], + }) + } + } + } + } + // Only aggregate value. + AggregateKind::Array(_) + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _, _) + | AggregateKind::Tuple => {} + }, + Rvalue::AddressOf(_, _) + | Rvalue::BinaryOp(_, _, _) + | Rvalue::CheckedBinaryOp(_, _, _) + | Rvalue::CopyForDeref(_) + | Rvalue::Discriminant(_) + | Rvalue::Len(_) + | Rvalue::Ref(_, _, _) + | Rvalue::Repeat(_, _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::NullaryOp(_, _) + | Rvalue::UnaryOp(_, _) + | Rvalue::Use(_) => {} + } + self.super_rvalue(rvalue, location); + } +} + +/// Gets the operand that corresponds to the assignment of the first sized field in memory. +/// +/// The first field of a structure is the only one that can have extra value restrictions imposed +/// by `rustc_layout_scalar_valid_range_*` attributes. +/// +/// Note: This requires at least one operand to be sized and there's a 1:1 match between operands +/// and field types. +fn first_aggregate_operand(dest_ty: Ty, dest_shape: &FieldsShape, operands: &[Operand]) -> Operand { + let Some(first) = first_sized_field_idx(dest_ty, dest_shape) else { unreachable!() }; + operands[first].clone() +} + +/// Index of the first non_1zst fields in memory order. +fn first_sized_field_idx(ty: Ty, shape: &FieldsShape) -> Option { + if let TyKind::RigidTy(RigidTy::Adt(adt_def, args)) = ty.kind() + && adt_def.kind() == AdtKind::Struct + { + let offset_order = shape.fields_by_offset_order(); + let fields = adt_def.variants_iter().next().unwrap().fields(); + offset_order + .into_iter() + .find(|idx| !fields[*idx].ty_with_args(&args).layout().unwrap().shape().is_1zst()) + } else { + None + } +} + +/// An assignment to a field with invalid values is unsafe, and it may trigger UB if +/// the assigned value is invalid. +/// +/// This can only happen to the first in memory sized field of a struct, and only if the field +/// type invalid range is a valid value for the rvalue type. +fn assignment_check_points( + machine_info: &MachineInfo, + locals: &[LocalDecl], + place: &Place, + rvalue_ty: Ty, +) -> Vec { + let mut ty = locals[place.local].ty; + let Some(rvalue_range) = ValidValueReq::try_from_ty(machine_info, rvalue_ty) else { + // Rvalue Abi must be Scalar / ScalarPair since destination must be Scalar / ScalarPair. + return vec![]; + }; + let mut invalid_ranges = vec![]; + for proj in &place.projection { + match proj { + ProjectionElem::Field(field_idx, field_ty) => { + let shape = ty.layout().unwrap().shape(); + if first_sized_field_idx(ty, &shape.fields) == Some(*field_idx) + && let Some(dest_valid) = ValidValueReq::try_from_ty(machine_info, ty) + && !dest_valid.is_full() + && dest_valid.size == rvalue_range.size + { + if !dest_valid.contains(&rvalue_range) { + invalid_ranges.push(dest_valid) + } + } else { + // Invalidate collected ranges so far since we are no longer in the path of + // the first element. + invalid_ranges.clear(); + } + ty = *field_ty; + } + ProjectionElem::Deref + | ProjectionElem::Index(_) + | ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subslice { .. } + | ProjectionElem::Downcast(_) + | ProjectionElem::OpaqueCast(_) + | ProjectionElem::Subtype(_) => ty = proj.ty(ty).unwrap(), + }; + } + invalid_ranges +} + +/// Retrieve instance for the given function operand. +/// +/// This will panic if the operand is not a function or if it cannot be resolved. +fn expect_instance(locals: &[LocalDecl], func: &Operand) -> Instance { + let ty = func.ty(locals).unwrap(); + match ty.kind() { + TyKind::RigidTy(RigidTy::FnDef(def, args)) => Instance::resolve(def, &args).unwrap(), + _ => unreachable!(), + } +} + +/// Traverse the type and find all invalid values and their location in memory. +/// +/// Not all values are currently supported. For those not supported, we return Error. +fn ty_validity_per_offset( + machine_info: &MachineInfo, + ty: Ty, + current_offset: usize, +) -> Result, String> { + let layout = ty.layout().unwrap().shape(); + let ty_req = || { + if let Some(mut req) = ValidValueReq::try_from_ty(machine_info, ty) + && !req.is_full() + { + req.offset = current_offset; + vec![req] + } else { + vec![] + } + }; + match layout.fields { + FieldsShape::Primitive => Ok(ty_req()), + FieldsShape::Array { stride, count } if count > 0 => { + let TyKind::RigidTy(RigidTy::Array(elem_ty, _)) = ty.kind() else { unreachable!() }; + let elem_validity = ty_validity_per_offset(machine_info, elem_ty, current_offset)?; + let mut result = vec![]; + if !elem_validity.is_empty() { + for idx in 0..count { + let idx: usize = idx.try_into().unwrap(); + let elem_offset = idx * stride.bytes(); + let mut next_validity = elem_validity + .iter() + .cloned() + .map(|mut req| { + req.offset += elem_offset; + req + }) + .collect::>(); + result.append(&mut next_validity) + } + } + Ok(result) + } + FieldsShape::Arbitrary { ref offsets } => { + match ty.kind().rigid().unwrap() { + RigidTy::Adt(def, args) => { + match def.kind() { + AdtKind::Enum => { + // Support basic enumeration forms + let ty_variants = def.variants(); + match layout.variants { + VariantsShape::Single { index } => { + // Only one variant is reachable. This behaves like a struct. + let fields = ty_variants[index.to_index()].fields(); + let mut fields_validity = vec![]; + for idx in layout.fields.fields_by_offset_order() { + let field_offset = offsets[idx].bytes(); + let field_ty = fields[idx].ty_with_args(&args); + fields_validity.append(&mut ty_validity_per_offset( + machine_info, + field_ty, + field_offset + current_offset, + )?); + } + Ok(fields_validity) + } + VariantsShape::Multiple { + tag_encoding: TagEncoding::Niche { .. }, + .. + } => { + Err(format!("Unsupported Enum `{}` check", def.trimmed_name()))? + } + VariantsShape::Multiple { variants, .. } => { + let enum_validity = ty_req(); + let mut fields_validity = vec![]; + for (index, variant) in variants.iter().enumerate() { + let fields = ty_variants[index].fields(); + for field_idx in variant.fields.fields_by_offset_order() { + let field_offset = offsets[field_idx].bytes(); + let field_ty = fields[field_idx].ty_with_args(&args); + fields_validity.append(&mut ty_validity_per_offset( + machine_info, + field_ty, + field_offset + current_offset, + )?); + } + } + if fields_validity.is_empty() { + Ok(enum_validity) + } else { + Err(format!( + "Unsupported Enum `{}` check", + def.trimmed_name() + )) + } + } + } + } + AdtKind::Union => unreachable!(), + AdtKind::Struct => { + // If the struct range has niche add that. + let mut struct_validity = ty_req(); + let fields = def.variants_iter().next().unwrap().fields(); + for idx in layout.fields.fields_by_offset_order() { + let field_offset = offsets[idx].bytes(); + let field_ty = fields[idx].ty_with_args(&args); + struct_validity.append(&mut ty_validity_per_offset( + machine_info, + field_ty, + field_offset + current_offset, + )?); + } + Ok(struct_validity) + } + } + } + RigidTy::Tuple(tys) => { + let mut tuple_validity = vec![]; + for idx in layout.fields.fields_by_offset_order() { + let field_offset = offsets[idx].bytes(); + let field_ty = tys[idx]; + tuple_validity.append(&mut ty_validity_per_offset( + machine_info, + field_ty, + field_offset + current_offset, + )?); + } + Ok(tuple_validity) + } + RigidTy::Bool + | RigidTy::Char + | RigidTy::Int(_) + | RigidTy::Uint(_) + | RigidTy::Float(_) + | RigidTy::Never => { + unreachable!("Expected primitive layout for {ty:?}") + } + RigidTy::Str | RigidTy::Slice(_) | RigidTy::Array(_, _) => { + unreachable!("Expected array layout for {ty:?}") + } + RigidTy::RawPtr(_, _) | RigidTy::Ref(_, _, _) => { + // Fat pointer has arbitrary shape. + Ok(ty_req()) + } + RigidTy::FnDef(_, _) + | RigidTy::FnPtr(_) + | RigidTy::Closure(_, _) + | RigidTy::Coroutine(_, _, _) + | RigidTy::CoroutineWitness(_, _) + | RigidTy::Foreign(_) + | RigidTy::Dynamic(_, _, _) => Err(format!("Unsupported {ty:?}")), + } + } + FieldsShape::Union(_) | FieldsShape::Array { .. } => { + /* Anything is valid */ + Ok(vec![]) + } + } +} diff --git a/kani-compiler/src/kani_middle/transform/mod.rs b/kani-compiler/src/kani_middle/transform/mod.rs new file mode 100644 index 000000000000..a6cd17e8c7db --- /dev/null +++ b/kani-compiler/src/kani_middle/transform/mod.rs @@ -0,0 +1,135 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// +//! This module is responsible for optimizing and instrumenting function bodies. +//! +//! We make transformations on bodies already monomorphized, which allow us to make stronger +//! decisions based on the instance types and constants. +//! +//! The main downside is that some transformation that don't depend on the specialized type may be +//! applied multiple times, one per specialization. +//! +//! Another downside is that these modifications cannot be applied to concrete playback, since they +//! are applied on the top of StableMIR body, which cannot be propagated back to rustc's backend. +//! +//! # Warn +//! +//! For all instrumentation passes, always use exhaustive matches to ensure soundness in case a new +//! case is added. +use crate::kani_middle::transform::check_values::ValidValuePass; +use crate::kani_queries::QueryDb; +use rustc_middle::ty::TyCtxt; +use stable_mir::mir::mono::Instance; +use stable_mir::mir::Body; +use std::collections::HashMap; +use std::fmt::Debug; + +mod body; +mod check_values; + +/// Object used to retrieve a transformed instance body. +/// The transformations to be applied may be controlled by user options. +/// +/// The order however is always the same, we run optimizations first, and instrument the code +/// after. +#[derive(Debug)] +pub struct BodyTransformation { + /// The passes that may optimize the function body. + /// We store them separately from the instrumentation passes because we run the in specific order. + opt_passes: Vec>, + /// The passes that may add safety checks to the function body. + inst_passes: Vec>, + /// Cache transformation results. + cache: HashMap, +} + +impl BodyTransformation { + pub fn new(queries: &QueryDb, tcx: TyCtxt) -> Self { + let mut transformer = BodyTransformation { + opt_passes: vec![], + inst_passes: vec![], + cache: Default::default(), + }; + transformer.add_pass(queries, ValidValuePass::new(tcx)); + transformer + } + + /// Allow the creation of a dummy transformer that doesn't apply any transformation due to + /// the stubbing validation hack (see `collect_and_partition_mono_items` override. + /// Once we move the stubbing logic to a [TransformPass], we should be able to remove this. + pub fn dummy() -> Self { + BodyTransformation { opt_passes: vec![], inst_passes: vec![], cache: Default::default() } + } + + /// Retrieve the body of an instance. + /// + /// Note that this assumes that the instance does have a body since existing consumers already + /// assume that. Use `instance.has_body()` to check if an instance has a body. + pub fn body(&mut self, tcx: TyCtxt, instance: Instance) -> Body { + match self.cache.get(&instance) { + Some(TransformationResult::Modified(body)) => body.clone(), + Some(TransformationResult::NotModified) => instance.body().unwrap(), + None => { + let mut body = instance.body().unwrap(); + let mut modified = false; + for pass in self.opt_passes.iter().chain(self.inst_passes.iter()) { + let result = pass.transform(tcx, body, instance); + modified |= result.0; + body = result.1; + } + + let result = if modified { + TransformationResult::Modified(body.clone()) + } else { + TransformationResult::NotModified + }; + self.cache.insert(instance, result); + body + } + } + } + + fn add_pass(&mut self, query_db: &QueryDb, pass: P) { + if pass.is_enabled(&query_db) { + match P::transformation_type() { + TransformationType::Instrumentation => self.inst_passes.push(Box::new(pass)), + TransformationType::Optimization => { + unreachable!() + } + } + } + } +} + +/// The type of transformation that a pass may perform. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +enum TransformationType { + /// Should only add assertion checks to ensure the program is correct. + Instrumentation, + /// May replace inefficient code with more performant but equivalent code. + #[allow(dead_code)] + Optimization, +} + +/// A trait to represent transformation passes that can be used to modify the body of a function. +trait TransformPass: Debug { + /// The type of transformation that this pass implements. + fn transformation_type() -> TransformationType + where + Self: Sized; + + fn is_enabled(&self, query_db: &QueryDb) -> bool + where + Self: Sized; + + /// Run a transformation pass in the function body. + fn transform(&self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body); +} + +/// The transformation result. +/// We currently only cache the body of functions that were instrumented. +#[derive(Clone, Debug)] +enum TransformationResult { + Modified(Body), + NotModified, +} diff --git a/kani-driver/src/call_single_file.rs b/kani-driver/src/call_single_file.rs index 4e8086e7e37b..24691943bc0f 100644 --- a/kani-driver/src/call_single_file.rs +++ b/kani-driver/src/call_single_file.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use anyhow::Result; +use kani_metadata::UnstableFeature; use std::ffi::OsString; use std::path::{Path, PathBuf}; use std::process::Command; @@ -100,6 +101,10 @@ impl KaniSession { flags.push("--coverage-checks".into()); } + if self.args.common_args.unstable_features.contains(UnstableFeature::ValidValueChecks) { + flags.push("--ub-check=validity".into()) + } + flags.extend(self.args.common_args.unstable_features.as_arguments().map(str::to_string)); // This argument will select the Kani flavour of the compiler. It will be removed before diff --git a/kani_metadata/src/unstable.rs b/kani_metadata/src/unstable.rs index 3820f3f2238e..878b468dbdc3 100644 --- a/kani_metadata/src/unstable.rs +++ b/kani_metadata/src/unstable.rs @@ -84,6 +84,9 @@ pub enum UnstableFeature { FunctionContracts, /// Memory predicate APIs. MemPredicates, + /// Automatically check that no invalid value is produced which is considered UB in Rust. + /// Note that this does not include checking uninitialized value. + ValidValueChecks, } impl UnstableFeature { diff --git a/tests/kani/ValidValues/constants.rs b/tests/kani/ValidValues/constants.rs new file mode 100644 index 000000000000..5230e6e5e6cb --- /dev/null +++ b/tests/kani/ValidValues/constants.rs @@ -0,0 +1,40 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// kani-flags: -Z valid-value-checks +//! Check that Kani can identify UB when it is reading from a constant. +//! Note that this UB will be removed for `-Z mir-opt-level=2` + +#[kani::proof] +fn transmute_valid_bool() { + let _b = unsafe { std::mem::transmute::(1) }; +} + +#[kani::proof] +fn cast_to_valid_char() { + let _c = unsafe { *(&100u32 as *const u32 as *const char) }; +} + +#[kani::proof] +fn cast_to_valid_offset() { + let val = [100u32, 80u32]; + let _c = unsafe { *(&val as *const [u32; 2] as *const [char; 2]) }; +} + +#[kani::proof] +#[kani::should_panic] +fn transmute_invalid_bool() { + let _b = unsafe { std::mem::transmute::(2) }; +} + +#[kani::proof] +#[kani::should_panic] +fn cast_to_invalid_char() { + let _c = unsafe { *(&u32::MAX as *const u32 as *const char) }; +} + +#[kani::proof] +#[kani::should_panic] +fn cast_to_invalid_offset() { + let val = [100u32, u32::MAX]; + let _c = unsafe { *(&val as *const [u32; 2] as *const [char; 2]) }; +} diff --git a/tests/kani/ValidValues/custom_niche.rs b/tests/kani/ValidValues/custom_niche.rs new file mode 100644 index 000000000000..02b5b87bd092 --- /dev/null +++ b/tests/kani/ValidValues/custom_niche.rs @@ -0,0 +1,100 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// kani-flags: -Z valid-value-checks +//! Check that Kani can identify UB when using niche attribute for a custom operation. +#![feature(rustc_attrs)] + +use std::mem; +use std::mem::size_of; + +/// A possible implementation for a system of rating that defines niche. +/// A Rating represents the number of stars of a given product (1..=5). +#[rustc_layout_scalar_valid_range_start(1)] +#[rustc_layout_scalar_valid_range_end(5)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Rating { + stars: u8, +} + +impl kani::Arbitrary for Rating { + fn any() -> Self { + let stars = kani::any_where(|s: &u8| *s >= 1 && *s <= 5); + unsafe { Rating { stars } } + } +} + +impl Rating { + /// Buggy version of new. Note that this still creates an invalid Rating. + /// + /// This is because `then_some` eagerly create the Rating value before assessing the condition. + /// Even though the value is never used, it is still considered UB. + pub fn new(value: u8) -> Option { + (value > 0 && value <= 5).then_some(unsafe { Rating { stars: value } }) + } + + pub unsafe fn new_unchecked(stars: u8) -> Rating { + Rating { stars } + } +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_new_with_ub() { + assert_eq!(Rating::new(10), None); +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_unchecked_new_ub() { + let val = kani::any(); + assert_eq!(unsafe { Rating::new_unchecked(val).stars }, val); +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_new_with_ub_limits() { + let stars = kani::any_where(|s: &u8| *s == 0 || *s > 5); + let _ = Rating::new(stars); +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_invalid_dereference() { + let any: u8 = kani::any(); + let _rating: Rating = unsafe { *(&any as *const _ as *const _) }; +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_invalid_transmute() { + let any: u8 = kani::any(); + let _rating: Rating = unsafe { mem::transmute(any) }; +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_invalid_transmute_copy() { + let any: u8 = kani::any(); + let _rating: Rating = unsafe { mem::transmute_copy(&any) }; +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_invalid_increment() { + let mut orig: Rating = kani::any(); + unsafe { orig.stars += 1 }; +} + +#[kani::proof] +pub fn check_valid_increment() { + let mut orig: Rating = kani::any(); + kani::assume(orig.stars < 5); + unsafe { orig.stars += 1 }; +} + +/// Check that the compiler relies on valid value range of Rating to implement niche optimization. +#[kani::proof] +pub fn check_niche() { + assert_eq!(size_of::(), size_of::>()); + assert_eq!(size_of::(), size_of::>>()); +} diff --git a/tests/kani/ValidValues/non_null.rs b/tests/kani/ValidValues/non_null.rs new file mode 100644 index 000000000000..4874b61bf2d0 --- /dev/null +++ b/tests/kani/ValidValues/non_null.rs @@ -0,0 +1,26 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// kani-flags: -Z valid-value-checks +//! Check that Kani can identify UB when unsafely writing to NonNull. + +use std::num::NonZeroU8; +use std::ptr::{self, NonNull}; + +#[kani::proof] +#[kani::should_panic] +pub fn check_invalid_value() { + let _ = unsafe { NonNull::new_unchecked(ptr::null_mut::()) }; +} + +#[kani::proof] +#[kani::should_panic] +pub fn check_invalid_value_cfg() { + let nn = unsafe { NonNull::new_unchecked(ptr::null_mut::()) }; + // This should be unreachable. TODO: Make this expected test. + assert_ne!(unsafe { nn.as_ref() }, &10); +} + +#[kani::proof] +pub fn check_valid_dangling() { + let _ = unsafe { NonNull::new_unchecked(4 as *mut u32) }; +} diff --git a/tests/kani/ValidValues/write_invalid.rs b/tests/kani/ValidValues/write_invalid.rs new file mode 100644 index 000000000000..05d3705bd69a --- /dev/null +++ b/tests/kani/ValidValues/write_invalid.rs @@ -0,0 +1,37 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// kani-flags: -Z valid-value-checks + +//! Check that Kani can identify UB after writing an invalid value. +//! Writing invalid bytes is not UB as long as the incorrect value is not read. +//! However, we over-approximate for sake of simplicity and performance. + +use std::num::NonZeroU8; + +#[kani::proof] +#[kani::should_panic] +pub fn write_invalid_bytes_no_ub_with_spurious_cex() { + let mut non_zero: NonZeroU8 = kani::any(); + let dest = &mut non_zero as *mut _; + unsafe { std::intrinsics::write_bytes(dest, 0, 1) }; +} + +#[kani::proof] +#[kani::should_panic] +pub fn write_valid_before_read() { + let mut non_zero: NonZeroU8 = kani::any(); + let mut non_zero_2: NonZeroU8 = kani::any(); + let dest = &mut non_zero as *mut _; + unsafe { std::intrinsics::write_bytes(dest, 0, 1) }; + unsafe { std::intrinsics::write_bytes(dest, non_zero_2.get(), 1) }; + assert_eq!(non_zero, non_zero_2) +} + +#[kani::proof] +#[kani::should_panic] +pub fn read_invalid_is_ub() { + let mut non_zero: NonZeroU8 = kani::any(); + let dest = &mut non_zero as *mut _; + unsafe { std::intrinsics::write_bytes(dest, 0, 1) }; + assert_eq!(non_zero.get(), 0) +}