diff --git a/library/kani_macros/src/derive.rs b/library/kani_macros/src/derive.rs index 7e3dee390330..4e99590fc6a3 100644 --- a/library/kani_macros/src/derive.rs +++ b/library/kani_macros/src/derive.rs @@ -23,7 +23,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok let item_name = &derive_item.ident; // Add a bound `T: Arbitrary` to every type parameter T. - let generics = add_trait_bound(derive_item.generics); + let generics = add_trait_bound_arbitrary(derive_item.generics); // Generate an expression to sum up the heap size of each field. let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); @@ -40,7 +40,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok } /// Add a bound `T: Arbitrary` to every type parameter T. -fn add_trait_bound(mut generics: Generics) -> Generics { +fn add_trait_bound_arbitrary(mut generics: Generics) -> Generics { generics.params.iter_mut().for_each(|param| { if let GenericParam::Type(type_param) = param { type_param.bounds.push(parse_quote!(kani::Arbitrary)); @@ -165,3 +165,93 @@ fn fn_any_enum(ident: &Ident, data: &DataEnum) -> TokenStream { } } } + +pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::TokenStream { + let derive_item = parse_macro_input!(item as DeriveInput); + let item_name = &derive_item.ident; + + // Add a bound `T: Invariant` to every type parameter T. + let generics = add_trait_bound_invariant(derive_item.generics); + // Generate an expression to sum up the heap size of each field. + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let body = is_safe_body(&item_name, &derive_item.data); + let expanded = quote! { + // The generated implementation. + impl #impl_generics kani::Invariant for #item_name #ty_generics #where_clause { + fn is_safe(&self) -> bool { + #body + } + } + }; + proc_macro::TokenStream::from(expanded) +} + +/// Add a bound `T: Invariant` to every type parameter T. +fn add_trait_bound_invariant(mut generics: Generics) -> Generics { + generics.params.iter_mut().for_each(|param| { + if let GenericParam::Type(type_param) = param { + type_param.bounds.push(parse_quote!(kani::Invariant)); + } + }); + generics +} + +fn is_safe_body(ident: &Ident, data: &Data) -> TokenStream { + match data { + Data::Struct(struct_data) => struct_safe_conjunction(ident, &struct_data.fields), + Data::Enum(_) => { + abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` enum", ident; + note = ident.span() => + "`#[derive(Invariant)]` cannot be used for enums such as `{}`", ident + ) + } + Data::Union(_) => { + abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` union", ident; + note = ident.span() => + "`#[derive(Invariant)]` cannot be used for unions such as `{}`", ident + ) + } + } +} + +/// Generates an expression that is the conjunction of `is_safe` calls for each field in the struct. +fn struct_safe_conjunction(_ident: &Ident, fields: &Fields) -> TokenStream { + match fields { + // Expands to the expression + // `true && self.field1.is_safe() && self.field2.is_safe() && ..` + Fields::Named(ref fields) => { + let safe_calls = fields.named.iter().map(|field| { + let name = &field.ident; + quote_spanned! {field.span()=> + self.#name.is_safe() + } + }); + // An initial value is required for empty structs + safe_calls.fold(quote! { true }, |acc, call| { + quote! { #acc && #call } + }) + } + Fields::Unnamed(ref fields) => { + // Expands to the expression + // `true && self.0.is_safe() && self.1.is_safe() && ..` + let safe_calls = fields.unnamed.iter().enumerate().map(|(i, field)| { + let idx = syn::Index::from(i); + quote_spanned! {field.span()=> + self.#idx.is_safe() + } + }); + // An initial value is required for empty structs + safe_calls.fold(quote! { true }, |acc, call| { + quote! { #acc && #call } + }) + } + // Expands to the expression + // `true` + Fields::Unit => { + quote! { + true + } + } + } +} diff --git a/library/kani_macros/src/lib.rs b/library/kani_macros/src/lib.rs index 0a000910174a..b10b8a74cdc5 100644 --- a/library/kani_macros/src/lib.rs +++ b/library/kani_macros/src/lib.rs @@ -107,6 +107,13 @@ pub fn derive_arbitrary(item: TokenStream) -> TokenStream { derive::expand_derive_arbitrary(item) } +/// Allow users to auto generate Invariant implementations by using `#[derive(Invariant)]` macro. +#[proc_macro_error] +#[proc_macro_derive(Invariant)] +pub fn derive_invariant(item: TokenStream) -> TokenStream { + derive::expand_derive_invariant(item) +} + /// Add a precondition to this function. /// /// This is part of the function contract API, for more general information see diff --git a/tests/expected/derive-invariant/empty_struct/empty_struct.rs b/tests/expected/derive-invariant/empty_struct/empty_struct.rs new file mode 100644 index 000000000000..4c931ce8aedc --- /dev/null +++ b/tests/expected/derive-invariant/empty_struct/empty_struct.rs @@ -0,0 +1,37 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Check that Kani can automatically derive `Invariant` for empty structs. + +extern crate kani; +use kani::Invariant; + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct Void; + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct Void2(()); + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct VoidOfVoid(Void, Void2); + +#[kani::proof] +fn check_empty_struct_invariant_1() { + let void1: Void = kani::any(); + assert!(void1.is_safe()); +} + +#[kani::proof] +fn check_empty_struct_invariant_2() { + let void2: Void2 = kani::any(); + assert!(void2.is_safe()); +} + +#[kani::proof] +fn check_empty_struct_invariant_3() { + let void3: VoidOfVoid = kani::any(); + assert!(void3.is_safe()); +} diff --git a/tests/expected/derive-invariant/empty_struct/expected b/tests/expected/derive-invariant/empty_struct/expected new file mode 100644 index 000000000000..8fdca72b1ead --- /dev/null +++ b/tests/expected/derive-invariant/empty_struct/expected @@ -0,0 +1,8 @@ + - Status: SUCCESS\ + - Description: "assertion failed: void1.is_safe()" + + - Status: SUCCESS\ + - Description: "assertion failed: void2.is_safe()" + + - Status: SUCCESS\ + - Description: "assertion failed: void3.is_safe()" diff --git a/tests/expected/derive-invariant/generic_struct/expected b/tests/expected/derive-invariant/generic_struct/expected new file mode 100644 index 000000000000..5e5886bb3e45 --- /dev/null +++ b/tests/expected/derive-invariant/generic_struct/expected @@ -0,0 +1,2 @@ + - Status: SUCCESS\ + - Description: "assertion failed: point.is_safe()" diff --git a/tests/expected/derive-invariant/generic_struct/generic_struct.rs b/tests/expected/derive-invariant/generic_struct/generic_struct.rs new file mode 100644 index 000000000000..91c62fac8ece --- /dev/null +++ b/tests/expected/derive-invariant/generic_struct/generic_struct.rs @@ -0,0 +1,20 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Check that Kani can automatically derive `Invariant` for structs with generics. + +extern crate kani; +use kani::Invariant; + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct Point { + x: X, + y: Y, +} + +#[kani::proof] +fn check_generic_struct_invariant() { + let point: Point = kani::any(); + assert!(point.is_safe()); +} diff --git a/tests/expected/derive-invariant/invariant_fail/expected b/tests/expected/derive-invariant/invariant_fail/expected new file mode 100644 index 000000000000..511d5901e154 --- /dev/null +++ b/tests/expected/derive-invariant/invariant_fail/expected @@ -0,0 +1,4 @@ + - Status: FAILURE\ + - Description: "assertion failed: wrapper.is_safe()" + +Verification failed for - check_invariant_fail diff --git a/tests/expected/derive-invariant/invariant_fail/invariant_fail.rs b/tests/expected/derive-invariant/invariant_fail/invariant_fail.rs new file mode 100644 index 000000000000..b1d6f8679835 --- /dev/null +++ b/tests/expected/derive-invariant/invariant_fail/invariant_fail.rs @@ -0,0 +1,33 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Check that a verification failure is triggered when the derived `Invariant` +//! method is checked but not satisfied. + +extern crate kani; +use kani::Invariant; +// Note: This represents an incorrect usage of `Arbitrary` and `Invariant`. +// +// The `Arbitrary` implementation should respect the type invariant, +// but Kani does not enforce this in any way at the moment. +// +#[derive(kani::Arbitrary)] +struct NotNegative(i32); + +impl kani::Invariant for NotNegative { + fn is_safe(&self) -> bool { + self.0 >= 0 + } +} + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct NotNegativeWrapper { + x: NotNegative, +} + +#[kani::proof] +fn check_invariant_fail() { + let wrapper: NotNegativeWrapper = kani::any(); + assert!(wrapper.is_safe()); +} diff --git a/tests/expected/derive-invariant/named_struct/expected b/tests/expected/derive-invariant/named_struct/expected new file mode 100644 index 000000000000..5e5886bb3e45 --- /dev/null +++ b/tests/expected/derive-invariant/named_struct/expected @@ -0,0 +1,2 @@ + - Status: SUCCESS\ + - Description: "assertion failed: point.is_safe()" diff --git a/tests/expected/derive-invariant/named_struct/named_struct.rs b/tests/expected/derive-invariant/named_struct/named_struct.rs new file mode 100644 index 000000000000..7e27404bda11 --- /dev/null +++ b/tests/expected/derive-invariant/named_struct/named_struct.rs @@ -0,0 +1,20 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Check that Kani can automatically derive `Invariant` for structs with named fields. + +extern crate kani; +use kani::Invariant; + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct Point { + x: i32, + y: i32, +} + +#[kani::proof] +fn check_generic_struct_invariant() { + let point: Point = kani::any(); + assert!(point.is_safe()); +} diff --git a/tests/expected/derive-invariant/unnamed_struct/expected b/tests/expected/derive-invariant/unnamed_struct/expected new file mode 100644 index 000000000000..5e5886bb3e45 --- /dev/null +++ b/tests/expected/derive-invariant/unnamed_struct/expected @@ -0,0 +1,2 @@ + - Status: SUCCESS\ + - Description: "assertion failed: point.is_safe()" diff --git a/tests/expected/derive-invariant/unnamed_struct/unnamed_struct.rs b/tests/expected/derive-invariant/unnamed_struct/unnamed_struct.rs new file mode 100644 index 000000000000..5dee718d05a6 --- /dev/null +++ b/tests/expected/derive-invariant/unnamed_struct/unnamed_struct.rs @@ -0,0 +1,17 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Check that Kani can automatically derive `Invariant` for structs with unnamed fields. + +extern crate kani; +use kani::Invariant; + +#[derive(kani::Arbitrary)] +#[derive(kani::Invariant)] +struct Point(i32, i32); + +#[kani::proof] +fn check_generic_struct_invariant() { + let point: Point = kani::any(); + assert!(point.is_safe()); +}