Skip to content

Commit

Permalink
feat(BV, CP): Add propagators for bvshl and bvlshr
Browse files Browse the repository at this point in the history
This patch adds interval and bitlist propagators for the bvshl (left
shift) and bvlshr (logical right shift) in the intervals and bitlist
domains for bit-vectors.

The interval propagator for left shift needs to be written specially in
order to properly deal with overflow, but the propagator for bvlshr is
written using a generic propagator for (bi)-monotone functions.

The bitlist propagator for bvshl is required because it needs to
propagate information regarding low bits that are not tracked by
intervals. However, I am not sure that the bitlist propagator for bvlshr
is actually needed since it might be subsumed by the interval propagator
for bvlshr (and consistency constraints) entirely, and we might want to
remove it.
  • Loading branch information
bclement-ocp committed May 17, 2024
1 parent 887f497 commit 3cc4dd1
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 4 deletions.
56 changes: 56 additions & 0 deletions src/lib/reasoners/bitlist.ml
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,59 @@ let mul a b =
in
concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@
concat mid_bits low_bits

let shl a b =
(* If the minimum value for [b] is larger than the bitwidth, the result is
zero.
Otherwise, any low zero bit in [a] is also a zero bit in the result, and
the minimum value for [b] also accounts for that many minimum zeros (e.g.
?000 shifted by at least 2 has at least 5 low zeroes).
NB: we would like to use the lower bound from the interval domain for [b]
here. *)
match Z.to_int (increase_lower_bound b Z.zero) with
| n when n < width a ->
let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in
if low_zeros + n >= width a then
exact (width a) Z.zero (Ex.union (explanation a) (explanation b))
else if low_zeros + n > 0 then
concat (unknown (width a - low_zeros - n) Ex.empty) @@
exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b))
else
unknown (width a) Ex.empty
| _ | exception Z.Overflow ->
exact (width a) Z.zero (explanation b)

let lshr a b =
(* If the minimum value for [b] is larger than the bitwidth, the result is
zero.
Otherwise, any high zero bit in [a] is also a zero bit in the result, and
the minimum value for [b] also accounts for that many minimum zeros (e.g.
000??? shifted by at least 2 is 00000?).
NB: we would like to use the lower bound from the interval domain for [b]
here. *)
match Z.to_int (increase_lower_bound b Z.zero) with
| n when n < width a ->
let sz = width a in
if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *)
let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in
let nb_msb_zeros = sz - low_msb_zero in
assert (nb_msb_zeros > 0);
let nb_zeros = nb_msb_zeros + n in
if nb_zeros >= sz then
exact sz Z.zero (Ex.union (explanation a) (explanation b))
else
concat
(exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b)))
(unknown (sz - nb_zeros) Ex.empty)
else if n > 0 then
concat
(exact n Z.zero (explanation b))
(unknown (sz - n) Ex.empty)
else
unknown sz Ex.empty
| _ | exception Z.Overflow ->
exact (width a) Z.zero (explanation b)
6 changes: 6 additions & 0 deletions src/lib/reasoners/bitlist.mli
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ val logxor : t -> t -> t
val mul : t -> t -> t
(** Multiplication. *)

val shl : t -> t -> t
(** Logical left shift. *)

val lshr : t -> t -> t
(** Logical right shift. *)

val concat : t -> t -> t
(** Bit-vector concatenation. *)

Expand Down
4 changes: 3 additions & 1 deletion src/lib/reasoners/bitv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ module Shostak(X : ALIEN) = struct
| Op (
Concat | Extract _ | BV2Nat
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem)
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr)
-> true
| _ -> false

Expand Down Expand Up @@ -409,6 +410,7 @@ module Shostak(X : ALIEN) = struct
| { f = Op (
BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr
); _ } ->
X.term_embed t, []
| _ -> X.make t
Expand Down
78 changes: 77 additions & 1 deletion src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ module Constraint : sig
This uses the convention that [x % 0] is [x]. *)

val bvshl : X.r -> X.r -> X.r -> t
(** [bvshl r x y] is the constraint [r = x << y] *)

val bvlshr : X.r -> X.r -> X.r -> t
(** [bvshl r x y] is the constraint [r = x >> y] *)

val bvule : X.r -> X.r -> t

val bvugt : X.r -> X.r -> t
Expand All @@ -273,6 +279,8 @@ end = struct
| Band | Bor | Bxor
(* Arithmetic operations *)
| Badd | Bmul | Budiv | Burem
(* Shift operations *)
| Bshl | Blshr

let pp_binop ppf = function
| Band -> Fmt.pf ppf "bvand"
Expand All @@ -282,6 +290,8 @@ end = struct
| Bmul -> Fmt.pf ppf "bvmul"
| Budiv -> Fmt.pf ppf "bvudiv"
| Burem -> Fmt.pf ppf "bvurem"
| Bshl -> Fmt.pf ppf "bvshl"
| Blshr -> Fmt.pf ppf "bvlshr"

let equal_binop op1 op2 =
match op1, op2 with
Expand All @@ -304,12 +314,18 @@ end = struct
| Budiv, _ | _, Budiv -> false

| Burem, Burem -> true
| Burem, _ | _, Burem -> false

| Bshl, Bshl -> true
| Bshl, _ | _, Bshl -> false

| Blshr, Blshr -> true

let hash_binop : binop -> int = Hashtbl.hash

let is_commutative = function
| Band | Bor | Bxor | Badd | Bmul -> true
| Budiv | Burem -> false
| Budiv | Burem | Bshl | Blshr -> false

let propagate_binop ~ex dx op dy dz =
let open Domains.Ephemeral in
Expand Down Expand Up @@ -343,6 +359,12 @@ end = struct
(* TODO: full adder propagation *)
()

| Bshl -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.shl !!dy !!dz)

| Blshr -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.lshr !!dy !!dz)

| Bmul -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.mul !!dy !!dz)

Expand All @@ -361,6 +383,12 @@ end = struct
update ~ex dy @@ norm @@ Intervals.Int.sub !!dr !!dx;
update ~ex dx @@ norm @@ Intervals.Int.sub !!dr !!dy

| Bshl -> (* Only forward propagation for now *)
update ~ex dr @@ Intervals.Int.bvshl ~size:sz !!dx !!dy

| Blshr -> (* Only forward propagation for now *)
update ~ex dr @@ Intervals.Int.lshr !!dx !!dy

| Bmul -> (* Only forward propagation for now *)
update ~ex dr @@ norm @@ Intervals.Int.mul !!dx !!dy

Expand Down Expand Up @@ -574,6 +602,8 @@ end = struct
let bvmul = cbinop Bmul
let bvudiv = cbinop Budiv
let bvurem = cbinop Burem
let bvshl = cbinop Bshl
let bvlshr = cbinop Blshr

let crel r = hcons @@ Crel r

Expand Down Expand Up @@ -729,6 +759,27 @@ end = struct
) else
false

(* Add the constraint: r = x >> c *)
let add_lshr_const acts r x c =
let sz = bitwidth r in
match Z.to_int c with
| 0 -> add_eq acts r x
| n when n < sz ->
assert (n > 0);
let r_bitv = Shostak.Bitv.embed r in
let low_bits =
Shostak.Bitv.is_mine @@
Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x)
in
add_eq acts
(Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv)
low_bits;
add_eq_const acts
(Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv)
Z.zero
| _ | exception Z.Overflow ->
add_eq_const acts r Z.zero

(* Ground evaluation rules for binary operators. *)
let eval_binop op ty x y =
match op with
Expand All @@ -747,6 +798,18 @@ end = struct
cast ty x
else
cast ty @@ Z.rem x y
| Bshl -> (
match ty, Z.to_int y with
| Tbitv sz, y when y < sz ->
cast ty @@ Z.shift_left x y
| _ | exception Z.Overflow -> cast ty Z.zero
)
| Blshr -> (
match ty, Z.to_int y with
| Tbitv sz, y when y < sz ->
cast ty @@ Z.shift_right x y
| _ | exception Z.Overflow -> cast ty Z.zero
)

(* Constant simplification rules for binary operators.
Expand Down Expand Up @@ -793,6 +856,17 @@ end = struct

| Budiv | Burem -> false

(* shifts becomes a simple extraction when we know the right-hand side *)
| Bshl when X.is_constant y ->
add_shl_const acts r x (value y);
true
| Bshl -> false

| Blshr when X.is_constant y ->
add_lshr_const acts r x (value y);
true
| Blshr -> false

(* Algebraic rewrite rules for binary operators.
Rules based on constant simplifications are in [rw_binop_const]. *)
Expand Down Expand Up @@ -864,6 +938,8 @@ let extract_binop =
| BVmul -> Some bvmul
| BVudiv -> Some bvudiv
| BVurem -> Some bvurem
| BVshl -> Some bvshl
| BVlshr -> Some bvlshr
| _ -> None

let extract_constraints bcs uf r t =
Expand Down
66 changes: 66 additions & 0 deletions src/lib/reasoners/intervals.ml
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,47 @@ module ZEuclideanType = struct
| Neg_infinite -> Pos_infinite
| Pos_infinite -> Neg_infinite
| Finite n -> Finite (Z.lognot n)

(* Values larger than [max_int] are treated as +oo *)
let shift_left ?(max_int = max_int) x y =
match y with
| Neg_infinite ->
Fmt.invalid_arg "shl: must shift by nonnegative amount"
| Pos_infinite -> Pos_infinite
| Finite y when Z.sign y < 0 ->
Fmt.invalid_arg "shl: must shift by nonnegative amount"
| Finite y ->
match Z.to_int y with
| exception Z.Overflow -> Pos_infinite
| y ->
if y <= max_int then
match x with
| Neg_infinite -> Neg_infinite
| Pos_infinite -> Pos_infinite
| Finite x -> Finite (Z.shift_left x y)
else Pos_infinite

let shift_right x y =
match y with
| Neg_infinite ->
invalid_arg "shift_right: must shift by nonnegative amount"
| Finite y when Z.sign y < 0 ->
invalid_arg "shift_right: must shift by nonnegative amount"
| Pos_infinite -> (
match x with
| Pos_infinite -> invalid_arg "shift_right: undefined limit"
| _ -> zero
)
| Finite y ->
match x with
| Neg_infinite -> Neg_infinite
| Pos_infinite -> Pos_infinite
| Finite x ->
match Z.to_int y with
| exception Z.Overflow ->
(* y > max_int -> x >> y = 0 since numbits x <= max_int *)
zero
| y -> Finite (Z.shift_right x y)
end

(* AlgebraicType interface for reals
Expand Down Expand Up @@ -663,6 +704,31 @@ module Int = struct
{ lb = ZEuclideanType.zero ; ub = ZEuclideanType.pred i2.ub }
) u1
) u2

let bvshl ~size u1 u2 =
assert (size > 0);
(* Values higher than [max_int] ultimately map to [0] *)
let max_int = size - 1 in
let zero_i = { lb = ZEuclideanType.zero ; ub = ZEuclideanType.zero } in
extract ~ofs:0 ~len:size @@
of_set_nonempty @@
map_to_set (fun i2 ->
assert (ZEuclideanType.sign i2.lb >= 0);
if ZEuclideanType.(compare i2.lb (finite @@ Z.of_int max_int)) > 0 then
(* if i2.lb > max_int, the result is always zero
must not call ZEuclideanType.shift_left or we will likely OOM *)
interval_set zero_i
else
(* equivalent to multiplication by a positive constant *)
approx_map_inc_to_set
(fun lb -> ZEuclideanType.shift_left lb i2.lb)
(fun ub -> ZEuclideanType.shift_left ~max_int ub i2.ub)
u1
) u2

let lshr u1 u2 =
of_set_nonempty @@
map2_mon_to_set ZEuclideanType.shift_right Inc u1 Dec u2
end

module Legacy = struct
Expand Down
13 changes: 13 additions & 0 deletions src/lib/reasoners/intervals.mli
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ module Int : sig
theory, i.e. where [bvurem n 0] is [n].
[s] and [t] must be within the [0, 2^sz - 1] range. *)

val bvshl : size:int -> t -> t -> t
(** [shl sz s t] computes an overapproximation of the left shift [s lsl t],
truncating the result to [sz] bits.
[s] and [t] must only contain non-negative integers. *)

val lshr : t -> t -> t
(** [lshr s t] computes an approximation of the logical right shift [s lsr t].
Note that the result of logical right shift is independent of bit width.
[s] and [t] must only contain non-negative integers. *)
end

module Legacy : sig
Expand Down
4 changes: 2 additions & 2 deletions src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3156,8 +3156,8 @@ module BV = struct
(bvneg u)

(* Shift operations *)
let bvshl s t = int2bv (size2 s t) Ints.(bv2nat s * (~$2 ** bv2nat t))
let bvlshr s t = int2bv (size2 s t) Ints.(bv2nat s / (~$2 ** bv2nat t))
let bvshl s t = mk_term (Op BVshl) [s; t] (type_info s)
let bvlshr s t = mk_term (Op BVlshr) [s; t] (type_info s)
let bvashr s t =
let m = size2 s t in
ite (is (extract (m - 1) (m - 1) s) 0)
Expand Down
6 changes: 6 additions & 0 deletions src/lib/structures/symbols.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type operator =
| Extract of int * int (* lower bound * upper bound *)
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr
| Int2BV of int | BV2Nat
(* FP *)
| Float
Expand Down Expand Up @@ -194,6 +195,7 @@ let compare_operators op1 op2 =
| Integer_log2 | Pow | Integer_round
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr
| Int2BV _ | BV2Nat
| Not_theory_constant | Is_theory_constant | Linear_dependency
| Constr _ | Destruct _ | Tite) -> assert false
Expand Down Expand Up @@ -358,6 +360,8 @@ module AEPrinter = struct
| BVmul -> Fmt.pf ppf "bvmul"
| BVudiv -> Fmt.pf ppf "bvudiv"
| BVurem -> Fmt.pf ppf "bvurem"
| BVshl -> Fmt.pf ppf "bvshl"
| BVlshr -> Fmt.pf ppf "bvlshr"

(* ArraysEx theory *)
| Get -> Fmt.pf ppf "get"
Expand Down Expand Up @@ -464,6 +468,8 @@ module SmtPrinter = struct
| BVmul -> Fmt.pf ppf "bvmul"
| BVudiv -> Fmt.pf ppf "bvudiv"
| BVurem -> Fmt.pf ppf "bvurem"
| BVshl -> Fmt.pf ppf "bvshl"
| BVlshr -> Fmt.pf ppf "bvlshr"

(* ArraysEx theory *)
| Get -> Fmt.pf ppf "select"
Expand Down
Loading

0 comments on commit 3cc4dd1

Please sign in to comment.