Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup(CP): Simplify constraint handling #1040

Merged
merged 3 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 64 additions & 60 deletions src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,18 @@ end = struct
let subst ex rr nrr t =
match MX.find rr t.bitlists with
| bl ->
(* Note: even if [rr] had changed its domain, we don't need to keep that
information because if the constraints that used to apply to [rr] were
not already valid, they will be marked as fresh in the [Constraints.t]
after substitution there and propagated already. *)
(* The substitution code for constraints requires that we properly update
the [changed] field here: if the domain of [rr] has changed,
constraints that applied to [rr] will apply to [nrr] after
substitution and must be propagated again. *)
let changed =
if SX.mem rr t.changed then
SX.add nrr t.changed
else
t.changed
in
let t =
{ changed = SX.remove rr t.changed
{ changed = SX.remove rr changed
; bitlists = MX.remove rr t.bitlists
}
in
Expand All @@ -245,19 +251,25 @@ end = struct
end

module Constraint : sig
include Rel_utils.Constraint with type domain = Domains.t
include Rel_utils.Constraint

val bvand : X.r -> X.r -> X.r -> t
(** [bvand x y z] is the constraint [x = y & z] *)

val bvand : ex:Ex.t -> X.r -> X.r -> X.r -> t
(** [bvand ~ex x y z] is the constraint [x = y & z] *)
val bvor : X.r -> X.r -> X.r -> t
(** [bvor x y z] is the constraint [x = y | z] *)

val bvor : ex:Ex.t -> X.r -> X.r -> X.r -> t
(** [bvor ~ex x y z] is the constraint [x = y | z] *)
val bvxor : X.r -> X.r -> X.r -> t
(** [bvxor x y z] is the constraint [x ^ y ^ z = 0] *)

val bvxor : ex:Ex.t -> X.r -> X.r -> X.r -> t
(** [bvxor ~ex x y z] is the constraint [x ^ y ^ z = 0] *)
val bvnot : X.r -> X.r -> t
(** [bvnot x y] is the constraint [x = not y] *)

val bvnot : ex:Ex.t -> X.r -> X.r -> t
(** [bvnot ~ex x y] is the constraint [x = not y] *)
val propagate : ex:Ex.t -> t -> Domains.t -> Domains.t
(** [propagate ~ex t dom] propagates the constraint [t] in domain [dom].

The explanation [ex] justifies that the constraint [t] applies, and must
be added to any domain that gets updated during propagation. *)
end = struct
type repr =
| Band of X.r * X.r * X.r
Expand Down Expand Up @@ -294,10 +306,10 @@ end = struct
Hashtbl.hash (2, SX.fold (fun r acc -> X.hash r :: acc) xs [])
| Bnot (x, y) -> Hashtbl.hash (2, X.hash x, X.hash y)

type tagged_repr = { repr : repr ; mutable tag : int }
type t = { repr : repr ; mutable tag : int }

module W = Weak.Make(struct
type t = tagged_repr
type nonrec t = t

let equal { repr = lhs; _ } { repr = rhs; _ } = equal_repr lhs rhs

Expand Down Expand Up @@ -355,19 +367,15 @@ end = struct
and y = X.subst rr nrr y in
Bnot (x, y)

(* The explanation justifies why the constraint holds. *)
type t = { repr : tagged_repr ; ex : Ex.t }

let pp ppf { repr; _ } = pp_repr ppf repr.repr
let pp ppf { repr; _ } = pp_repr ppf repr

let compare { repr = r1; _ } { repr = r2; _ } =
Int.compare r1.tag r2.tag
let compare { tag = t1; _ } { tag = t2; _ } = Stdlib.compare t1 t2

let subst ex rr nrr c =
{ repr = hcons @@ subst_repr rr nrr c.repr.repr ; ex = Ex.union ex c.ex }
let subst rr nrr c =
hcons @@ subst_repr rr nrr c.repr

let fold_deps f { repr; _ } acc =
match repr.repr with
let fold_args f { repr; _ } acc =
match repr with
| Band (x, y, z) | Bor (x, y, z) ->
let acc = f x acc in
let acc = f y acc in
Expand All @@ -379,16 +387,9 @@ end = struct
let acc = f y acc in
acc

let fold_leaves f c acc =
fold_deps (fun r acc ->
List.fold_left (fun acc r -> f r acc) acc (X.leaves r)
) c acc

type domain = Domains.t

let propagate { repr; ex } dom =
let propagate ~ex { repr; _ } dom =
Steps.incr CP;
match repr.repr with
match repr with
| Band (x, y, z) ->
let dx = Domains.get x dom
and dy = Domains.get y dom
Expand Down Expand Up @@ -448,39 +449,37 @@ end = struct
let dom = Domains.update ex y dom @@ Bitlist.lognot dx in
dom

let make ?(ex = Ex.empty) repr = { repr = hcons repr ; ex }

let bvand ~ex x y z = make ~ex @@ Band (x, y, z)
let bvor ~ex x y z = make ~ex @@ Bor (x, y, z)
let bvxor ~ex x y z =
let bvand x y z = hcons @@ Band (x, y, z)
let bvor x y z = hcons @@ Bor (x, y, z)
let bvxor x y z =
let xs = SX.singleton x in
let xs = if SX.mem y xs then SX.remove y xs else SX.add y xs in
let xs = if SX.mem z xs then SX.remove z xs else SX.add z xs in
make ~ex @@ Bxor xs
let bvnot ~ex x y = make ~ex @@ Bnot (x, y)
hcons @@ Bxor xs
let bvnot x y = hcons @@ Bnot (x, y)
end

module Constraints = Rel_utils.Constraints_Make(Constraint)
module Constraints = Rel_utils.Constraints_make(Constraint)

let extract_constraints bcs uf r t =
match E.term_view t with
(* BVnot is already internalized in the Shostak but we want to know about it
without needing a round-trip through Uf *)
| { f = Op BVnot; xs = [ x ] ; _ } ->
let rx, exx = Uf.find uf x in
Constraints.add bcs @@ Constraint.bvnot ~ex:exx r rx
Constraints.add ~ex:exx (Constraint.bvnot r rx) bcs
| { f = Op BVand; xs = [ x; y ]; _ } ->
let rx, exx = Uf.find uf x
and ry, exy = Uf.find uf y in
Constraints.add bcs @@ Constraint.bvand ~ex:(Ex.union exx exy) r rx ry
Constraints.add ~ex:(Ex.union exx exy) (Constraint.bvand r rx ry) bcs
| { f = Op BVor; xs = [ x; y ]; _ } ->
let rx, exx = Uf.find uf x
and ry, exy = Uf.find uf y in
Constraints.add bcs @@ Constraint.bvor ~ex:(Ex.union exx exy) r rx ry
Constraints.add ~ex:(Ex.union exx exy) (Constraint.bvor r rx ry) bcs
| { f = Op BVxor; xs = [ x; y ]; _ } ->
let rx, exx = Uf.find uf x
and ry, exy = Uf.find uf y in
Constraints.add bcs @@ Constraint.bvxor ~ex:(Ex.union exx exy) r rx ry
Constraints.add ~ex:(Ex.union exx exy) (Constraint.bvxor r rx ry) bcs
| _ -> bcs

let rec mk_eq ex lhs w z =
Expand Down Expand Up @@ -532,21 +531,26 @@ let add_eqs =

(* Propagate:

- The constraints that were never propagated since they were added (this
includes constraints that changed due to substitutions)
- The constraints that were never propagated since they were added
- The constraints involving variables whose domain changed since the last
propagation *)
propagation

Iterate until fixpoint is reached. *)
let propagate =
let rec propagate changed bcs dom =
match Domains.choose_changed dom with
| r, dom ->
propagate (SX.add r changed) bcs @@
Constraints.propagate bcs r dom
| exception Not_found -> changed, dom
match Constraints.next_pending bcs with
| { value; explanation = ex }, bcs ->
let dom = Constraint.propagate ~ex value dom in
propagate changed bcs dom
| exception Not_found ->
match Domains.choose_changed dom with
| r, dom ->
propagate (SX.add r changed) (Constraints.notify_leaf r bcs) dom
| exception Not_found ->
changed, bcs, dom
in
fun bcs dom ->
let bcs, dom = Constraints.propagate_fresh bcs dom in
let changed, dom = propagate SX.empty bcs dom in
let changed, bcs, dom = propagate SX.empty bcs dom in
SX.fold (fun r acc ->
add_eqs acc (Shostak.Bitv.embed r) (Domains.get r dom)
) changed [], bcs, dom
Expand Down Expand Up @@ -582,7 +586,7 @@ let assume env uf la =
match a, orig with
| L.Eq (rr, nrr), Subst when is_bv_r rr ->
let dom = Domains.subst ex rr nrr dom in
let bcs = Constraints.subst ex rr nrr bcs in
let bcs = Constraints.subst ~ex rr nrr bcs in
((bcs, dom), ss)
| L.Distinct (false, [rr; nrr]), _ when is_1bit rr ->
(* We don't (yet) support [distinct] in general, but we must
Expand All @@ -597,7 +601,7 @@ let assume env uf la =
let nrr, exnrr = Uf.find_r uf nrr in
let ex = Ex.union ex (Ex.union exrr exnrr) in
let bcs =
Constraints.add bcs @@ Constraint.bvnot ~ex rr nrr
Constraints.add ~ex (Constraint.bvnot rr nrr) bcs
in
((bcs, dom), ss)
| _ -> ((bcs, dom), ss)
Expand Down Expand Up @@ -651,7 +655,7 @@ let case_split env _uf ~for_model =
in
let _, candidates =
match
Constraints.fold_r (fun r acc ->
Constraints.fold_args (fun r acc ->
List.fold_left (fun acc { Bitv.bv; _ } ->
match bv with
| Bitv.Cte _ -> acc
Expand Down
Loading
Loading