diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml index f526d8fcb6..74623b34db 100644 --- a/src/lib/reasoners/bitlist.ml +++ b/src/lib/reasoners/bitlist.ml @@ -289,3 +289,59 @@ let mul a b = in concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@ concat mid_bits low_bits + +let shl a b = + (* If the minimum value for [b] is larger than the bitwidth, the result is + zero. + + Otherwise, any low zero bit in [a] is also a zero bit in the result, and + the minimum value for [b] also accounts for that many minimum zeros (e.g. + ?000 shifted by at least 2 has at least 5 low zeroes). + + NB: we would like to use the lower bound from the interval domain for [b] + here. *) + match Z.to_int (increase_lower_bound b Z.zero) with + | n when n < width a -> + let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in + if low_zeros + n >= width a then + exact (width a) Z.zero (Ex.union (explanation a) (explanation b)) + else if low_zeros + n > 0 then + concat (unknown (width a - low_zeros - n) Ex.empty) @@ + exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b)) + else + unknown (width a) Ex.empty + | _ | exception Z.Overflow -> + exact (width a) Z.zero (explanation b) + +let lshr a b = + (* If the minimum value for [b] is larger than the bitwidth, the result is + zero. + + Otherwise, any high zero bit in [a] is also a zero bit in the result, and + the minimum value for [b] also accounts for that many minimum zeros (e.g. + 000??? shifted by at least 2 is 00000?). + + NB: we would like to use the lower bound from the interval domain for [b] + here. *) + match Z.to_int (increase_lower_bound b Z.zero) with + | n when n < width a -> + let sz = width a in + if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *) + let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in + let nb_msb_zeros = sz - low_msb_zero in + assert (nb_msb_zeros > 0); + let nb_zeros = nb_msb_zeros + n in + if nb_zeros >= sz then + exact sz Z.zero (Ex.union (explanation a) (explanation b)) + else + concat + (exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b))) + (unknown (sz - nb_zeros) Ex.empty) + else if n > 0 then + concat + (exact n Z.zero (explanation b)) + (unknown (sz - n) Ex.empty) + else + unknown sz Ex.empty + | _ | exception Z.Overflow -> + exact (width a) Z.zero (explanation b) diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli index 1a4e22b51a..3341cbd859 100644 --- a/src/lib/reasoners/bitlist.mli +++ b/src/lib/reasoners/bitlist.mli @@ -111,6 +111,12 @@ val logxor : t -> t -> t val mul : t -> t -> t (** Multiplication. *) +val shl : t -> t -> t +(** Logical left shift. *) + +val lshr : t -> t -> t +(** Logical right shift. *) + val concat : t -> t -> t (** Bit-vector concatenation. *) diff --git a/src/lib/reasoners/bitv.ml b/src/lib/reasoners/bitv.ml index 0ffa395ac0..66825c33d9 100644 --- a/src/lib/reasoners/bitv.ml +++ b/src/lib/reasoners/bitv.ml @@ -353,7 +353,8 @@ module Shostak(X : ALIEN) = struct | Op ( Concat | Extract _ | BV2Nat | BVnot | BVand | BVor | BVxor - | BVadd | BVsub | BVmul | BVudiv | BVurem) + | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr) -> true | _ -> false @@ -412,6 +413,7 @@ module Shostak(X : ALIEN) = struct | { f = Op ( BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr ); _ } -> X.term_embed t, [] | _ -> X.make t diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index beaa80f3e8..91af250ae0 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -254,6 +254,12 @@ module Constraint : sig This uses the convention that [x % 0] is [x]. *) + val bvshl : X.r -> X.r -> X.r -> t + (** [bvshl r x y] is the constraint [r = x << y] *) + + val bvlshr : X.r -> X.r -> X.r -> t + (** [bvshl r x y] is the constraint [r = x >> y] *) + val bvule : X.r -> X.r -> t val bvugt : X.r -> X.r -> t @@ -271,6 +277,8 @@ end = struct | Band | Bor | Bxor (* Arithmetic operations *) | Badd | Bmul | Budiv | Burem + (* Shift operations *) + | Bshl | Blshr let pp_binop ppf = function | Band -> Fmt.pf ppf "bvand" @@ -280,6 +288,8 @@ end = struct | Bmul -> Fmt.pf ppf "bvmul" | Budiv -> Fmt.pf ppf "bvudiv" | Burem -> Fmt.pf ppf "bvurem" + | Bshl -> Fmt.pf ppf "bvshl" + | Blshr -> Fmt.pf ppf "bvlshr" let equal_binop : binop -> binop -> bool = Stdlib.(=) @@ -287,7 +297,7 @@ end = struct let is_commutative = function | Band | Bor | Bxor | Badd | Bmul -> true - | Budiv | Burem -> false + | Budiv | Burem | Bshl | Blshr -> false let propagate_binop ~ex dx op dy dz = let open Domains.Ephemeral in @@ -321,6 +331,12 @@ end = struct (* TODO: full adder propagation *) () + | Bshl -> (* Only forward propagation for now *) + update ~ex dx (Bitlist.shl !!dy !!dz) + + | Blshr -> (* Only forward propagation for now *) + update ~ex dx (Bitlist.lshr !!dy !!dz) + | Bmul -> (* Only forward propagation for now *) update ~ex dx (Bitlist.mul !!dy !!dz) @@ -339,6 +355,12 @@ end = struct update ~ex dy @@ norm @@ Intervals.Int.sub !!dr !!dx; update ~ex dx @@ norm @@ Intervals.Int.sub !!dr !!dy + | Bshl -> (* Only forward propagation for now *) + update ~ex dr @@ Intervals.Int.bvshl ~size:sz !!dx !!dy + + | Blshr -> (* Only forward propagation for now *) + update ~ex dr @@ Intervals.Int.lshr !!dx !!dy + | Bmul -> (* Only forward propagation for now *) update ~ex dr @@ norm @@ Intervals.Int.mul !!dx !!dy @@ -552,6 +574,8 @@ end = struct let bvmul = cbinop Bmul let bvudiv = cbinop Budiv let bvurem = cbinop Burem + let bvshl = cbinop Bshl + let bvlshr = cbinop Blshr let crel r = hcons @@ Crel r @@ -707,6 +731,27 @@ end = struct ) else false + (* Add the constraint: r = x >> c *) + let add_lshr_const acts r x c = + let sz = bitwidth r in + match Z.to_int c with + | 0 -> add_eq acts r x + | n when n < sz -> + assert (n > 0); + let r_bitv = Shostak.Bitv.embed r in + let low_bits = + Shostak.Bitv.is_mine @@ + Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x) + in + add_eq acts + (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv) + low_bits; + add_eq_const acts + (Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv) + Z.zero + | _ | exception Z.Overflow -> + add_eq_const acts r Z.zero + (* Ground evaluation rules for binary operators. *) let eval_binop op ty x y = match op with @@ -725,6 +770,18 @@ end = struct cast ty x else cast ty @@ Z.rem x y + | Bshl -> ( + match ty, Z.to_int y with + | Tbitv sz, y when y < sz -> + cast ty @@ Z.shift_left x y + | _ | exception Z.Overflow -> cast ty Z.zero + ) + | Blshr -> ( + match ty, Z.to_int y with + | Tbitv sz, y when y < sz -> + cast ty @@ Z.shift_right x y + | _ | exception Z.Overflow -> cast ty Z.zero + ) (* Constant simplification rules for binary operators. @@ -771,6 +828,17 @@ end = struct | Budiv | Burem -> false + (* shifts becomes a simple extraction when we know the right-hand side *) + | Bshl when X.is_constant y -> + add_shl_const acts r x (value y); + true + | Bshl -> false + + | Blshr when X.is_constant y -> + add_lshr_const acts r x (value y); + true + | Blshr -> false + (* Algebraic rewrite rules for binary operators. Rules based on constant simplifications are in [rw_binop_const]. *) @@ -842,6 +910,8 @@ let extract_binop = | BVmul -> Some bvmul | BVudiv -> Some bvudiv | BVurem -> Some bvurem + | BVshl -> Some bvshl + | BVlshr -> Some bvlshr | _ -> None let extract_constraints bcs uf r t = diff --git a/src/lib/reasoners/intervals.ml b/src/lib/reasoners/intervals.ml index 573c62d004..2180b59448 100644 --- a/src/lib/reasoners/intervals.ml +++ b/src/lib/reasoners/intervals.ml @@ -314,6 +314,40 @@ module ZEuclideanType = struct | Neg_infinite -> Pos_infinite | Pos_infinite -> Neg_infinite | Finite n -> Finite (Z.lognot n) + + (* Any value higher than [size] maps to [+oo] to avoid huge terms. *) + let shift_left ?(size = max_int) x y = + match y with + | Neg_infinite -> + Fmt.invalid_arg "shl: must shift by nonnegative amount" + | Pos_infinite -> Pos_infinite + | Finite y when Z.sign y < 0 -> + Fmt.invalid_arg "shl: must shift by nonnegative amount" + | Finite y -> + match Z.to_int y with + | exception Z.Overflow -> Pos_infinite + | y when y >= size -> Pos_infinite + | y -> + match x with + | Neg_infinite -> Neg_infinite + | Pos_infinite -> Pos_infinite + | Finite x -> Finite (Z.shift_left x y) + + let shift_right x y = + match x, y with + | _, Neg_infinite -> + Fmt.invalid_arg "shift_right: must shift by nonnegative amount" + | _, Finite y when Z.sign y < 0 -> + Fmt.invalid_arg "shift_right: must shift by nonnegative amount" + | Pos_infinite, Pos_infinite -> + Fmt.invalid_arg "shift_right: undefined limit" + | _, Pos_infinite -> zero + | Neg_infinite, Finite _ -> Neg_infinite + | Pos_infinite, Finite _ -> Pos_infinite + | Finite x, Finite y -> + match Z.to_int y with + | exception Z.Overflow -> zero + | y -> Finite (Z.shift_right x y) end (* AlgebraicType interface for reals @@ -663,6 +697,31 @@ module Int = struct { lb = ZEuclideanType.zero ; ub = ZEuclideanType.pred i2.ub } ) u1 ) u2 + + let shl_overflows ~size n = + ZEuclideanType.(compare n (finite @@ Z.of_int size)) >= 0 + + let bvshl ~size u1 u2 = + let zero_i = { lb = ZEuclideanType.zero ; ub = ZEuclideanType.zero } in + extract ~ofs:0 ~len:size @@ + of_set_nonempty @@ + map_to_set (fun i2 -> + if shl_overflows ~size i2.lb then + (* if i2.lb >= sz, the result is always zero + must not call ZEuclideanType.shift_left or we will likely OOM *) + interval_set zero_i + else + trisection_map_to_set ZEuclideanType.zero u1 + (fun _ -> invalid_arg "bvshl: negative argument") + (fun () -> interval_set zero_i) + (approx_map_inc_to_set + (fun lb -> ZEuclideanType.shift_left lb i2.lb) + (fun ub -> ZEuclideanType.shift_left ~size ub i2.ub)) + ) u2 + + let lshr u1 u2 = + of_set_nonempty @@ + map2_mon_to_set ZEuclideanType.shift_right Inc u1 Dec u2 end module Legacy = struct diff --git a/src/lib/reasoners/intervals.mli b/src/lib/reasoners/intervals.mli index 65d7fa3250..a501188155 100644 --- a/src/lib/reasoners/intervals.mli +++ b/src/lib/reasoners/intervals.mli @@ -88,6 +88,19 @@ module Int : sig theory, i.e. where [bvurem n 0] is [n]. [s] and [t] must be within the [0, 2^sz - 1] range. *) + + val bvshl : size:int -> t -> t -> t + (** [shl sz s t] computes an overapproximation of the left shift [s lsl t], + truncating the result to [sz] bits. + + [s] and [t] must only contain non-negative integers. *) + + val lshr : t -> t -> t + (** [lshr s t] computes an approximation of the logical right shift [s lsr t]. + + Note that the result of logical right shift is independent of bit width. + + [s] and [t] must only contain non-negative integers. *) end module Legacy : sig diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index 26885ad325..e0d1dc3b79 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -3102,8 +3102,8 @@ module BV = struct (bvneg u) (* Shift operations *) - let bvshl s t = int2bv (size2 s t) Ints.(bv2nat s * (~$2 ** bv2nat t)) - let bvlshr s t = int2bv (size2 s t) Ints.(bv2nat s / (~$2 ** bv2nat t)) + let bvshl s t = mk_term (Op BVshl) [s; t] (type_info s) + let bvlshr s t = mk_term (Op BVlshr) [s; t] (type_info s) let bvashr s t = let m = size2 s t in ite (is (extract (m - 1) (m - 1) s) 0) diff --git a/src/lib/structures/symbols.ml b/src/lib/structures/symbols.ml index 8cec9761d6..0d0f7f88f3 100644 --- a/src/lib/structures/symbols.ml +++ b/src/lib/structures/symbols.ml @@ -48,6 +48,7 @@ type operator = | Extract of int * int (* lower bound * upper bound *) | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV of int | BV2Nat (* FP *) | Float @@ -199,6 +200,7 @@ let compare_operators op1 op2 = | Integer_log2 | Pow | Integer_round | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV _ | BV2Nat | Not_theory_constant | Is_theory_constant | Linear_dependency | Constr _ | Destruct _ | Tite) -> assert false @@ -363,6 +365,8 @@ module AEPrinter = struct | BVmul -> Fmt.pf ppf "bvmul" | BVudiv -> Fmt.pf ppf "bvudiv" | BVurem -> Fmt.pf ppf "bvurem" + | BVshl -> Fmt.pf ppf "bvshl" + | BVlshr -> Fmt.pf ppf "bvlshr" (* ArraysEx theory *) | Get -> Fmt.pf ppf "get" @@ -469,6 +473,8 @@ module SmtPrinter = struct | BVmul -> Fmt.pf ppf "bvmul" | BVudiv -> Fmt.pf ppf "bvudiv" | BVurem -> Fmt.pf ppf "bvurem" + | BVshl -> Fmt.pf ppf "bvshl" + | BVlshr -> Fmt.pf ppf "bvlshr" (* ArraysEx theory *) | Get -> Fmt.pf ppf "select" diff --git a/src/lib/structures/symbols.mli b/src/lib/structures/symbols.mli index 2a533d24ec..a9c08b25a4 100644 --- a/src/lib/structures/symbols.mli +++ b/src/lib/structures/symbols.mli @@ -48,6 +48,7 @@ type operator = | Extract of int * int (* lower bound * upper bound *) | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV of int | BV2Nat (* FP *) | Float diff --git a/tests/bitvec_tests.ml b/tests/bitvec_tests.ml index 013f6ac184..9a5baf1274 100644 --- a/tests/bitvec_tests.ml +++ b/tests/bitvec_tests.ml @@ -276,6 +276,44 @@ let test_bitlist_mul sz = let () = Test.check_exn (test_bitlist_mul 3) +let zshl sz a b = + match Z.to_int b with + | b when b < sz -> Z.extract (Z.shift_left a b) 0 sz + | _ | exception Z.Overflow -> Z.zero + +let test_interval_shl sz = + test_interval_binop ~count:1_000 + sz (zshl sz) (Intervals.Int.bvshl ~size:sz) + +let () = + Test.check_exn (test_interval_shl 3) + +let test_bitlist_shl sz = + test_bitlist_binop ~count:1_000 + sz (zshl sz) Bitlist.shl + +let () = + Test.check_exn (test_bitlist_shl 3) + +let zlshr a b = + match Z.to_int b with + | b -> Z.shift_right a b + | exception Z.Overflow -> Z.zero + +let test_interval_lshr sz = + test_interval_binop ~count:1_000 + sz zlshr Intervals.Int.lshr + +let () = + Test.check_exn (test_interval_lshr 3) + +let test_bitlist_lshr sz = + test_bitlist_binop ~count:1_000 + sz zlshr Bitlist.lshr + +let () = + Test.check_exn (test_bitlist_lshr 3) + let zudiv sz a b = if Z.equal b Z.zero then Z.extract Z.minus_one 0 sz