Skip to content

Commit

Permalink
Add Expr.compare function (Closes #161)
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeom committed Oct 11, 2024
1 parent e10a0c2 commit 429e52f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 56 deletions.
100 changes: 49 additions & 51 deletions src/ast/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
(* along with this program. If not, see <https://www.gnu.org/licenses/>. *)
(***************************************************************************)

open Ty

type binder =
| Forall
| Exists
Expand All @@ -34,12 +32,12 @@ and expr =
| Symbol of Symbol.t
| List of t list
| App of Symbol.t * t list
| Unop of Ty.t * unop * t
| Binop of Ty.t * binop * t * t
| Triop of Ty.t * triop * t * t * t
| Relop of Ty.t * relop * t * t
| Cvtop of Ty.t * cvtop * t
| Naryop of Ty.t * naryop * t list
| Unop of Ty.t * Ty.unop * t
| Binop of Ty.t * Ty.binop * t * t
| Triop of Ty.t * Ty.triop * t * t * t
| Relop of Ty.t * Ty.relop * t * t
| Cvtop of Ty.t * Ty.cvtop * t
| Naryop of Ty.t * Ty.naryop * t list
| Extract of t * int * int
| Concat of t * t
| Binder of binder * t list * t
Expand Down Expand Up @@ -124,10 +122,10 @@ module Set = PatriciaTree.MakeHashconsedSet (Key) ()

let make (e : expr) = Hc.hashcons e [@@inline]

let ( @: ) e _ = make e

let view (hte : t) : expr = hte.node [@@inline]

let compare (hte1 : t) (hte2 : t) = compare hte1.tag hte2.tag [@@inline]

let symbol s = make (Symbol s)

let is_num (e : t) = match view e with Val (Num _) -> true | _ -> false
Expand Down Expand Up @@ -245,18 +243,20 @@ module Pp = struct
(Fmt.list ~sep:Fmt.comma pp)
v
| Unop (ty, op, e) ->
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a)@]" Ty.pp ty pp_unop op pp e
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a)@]" Ty.pp ty Ty.pp_unop op pp e
| Binop (ty, op, e1, e2) ->
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a@ %a)@]" Ty.pp ty pp_binop op pp e1 pp e2
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.pp_binop op pp e1 pp
e2
| Triop (ty, op, e1, e2, e3) ->
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a@ %a@ %a)@]" Ty.pp ty pp_triop op pp e1 pp
e2 pp e3
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a@ %a@ %a)@]" Ty.pp ty Ty.pp_triop op pp e1
pp e2 pp e3
| Relop (ty, op, e1, e2) ->
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a@ %a)@]" Ty.pp ty pp_relop op pp e1 pp e2
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.pp_relop op pp e1 pp
e2
| Cvtop (ty, op, e) ->
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a)@]" Ty.pp ty pp_cvtop op pp e
Fmt.pf fmt "@[<hov 1>(%a.%a@ %a)@]" Ty.pp ty Ty.pp_cvtop op pp e
| Naryop (ty, op, es) ->
Fmt.pf fmt "@[<hov 1>(%a.%a@ (%a))@]" Ty.pp ty pp_naryop op
Fmt.pf fmt "@[<hov 1>(%a.%a@ (%a))@]" Ty.pp ty Ty.pp_naryop op
(Fmt.list ~sep:Fmt.comma pp)
es
| Extract (e, h, l) ->
Expand Down Expand Up @@ -303,12 +303,11 @@ let app symbol args = make (App (symbol, args))

let let_in vars expr = make (Binder (Let_in, vars, expr))

let unop' (ty : Ty.t) (op : unop) (hte : t) : t = make (Unop (ty, op, hte))
[@@inline]
let unop' ty op hte = make (Unop (ty, op, hte)) [@@inline]

let unop (ty : Ty.t) (op : unop) (hte : t) : t =
let unop ty op hte =
match (op, view hte) with
| (Regexp_loop _ | Regexp_star), _ -> unop' ty op hte
| Ty.(Regexp_loop _ | Regexp_star), _ -> unop' ty op hte
| _, Val v -> value (Eval.unop ty op v)
| Not, Unop (_, Not, hte') -> hte'
| Neg, Unop (_, Neg, hte') -> hte'
Expand All @@ -319,13 +318,11 @@ let unop (ty : Ty.t) (op : unop) (hte : t) : t =
| Length, List es -> value (Int (List.length es))
| _ -> unop' ty op hte

let binop' (ty : Ty.t) (op : binop) (hte1 : t) (hte2 : t) : t =
make (Binop (ty, op, hte1, hte2))
[@@inline]
let binop' ty op hte1 hte2 = make (Binop (ty, op, hte1, hte2)) [@@inline]

let rec binop ty (op : binop) (hte1 : t) (hte2 : t) : t =
let rec binop ty op hte1 hte2 =
match (op, view hte1, view hte2) with
| (String_in_re | Regexp_range), _, _ -> binop' ty op hte1 hte2
| Ty.(String_in_re | Regexp_range), _, _ -> binop' ty op hte1 hte2
| op, Val v1, Val v2 -> value (Eval.binop ty op v1 v2)
| Sub, Ptr { base = b1; offset = os1 }, Ptr { base = b2; offset = os2 } ->
if Int32.equal b1 b2 then binop ty Sub os1 os2 else binop' ty op hte1 hte2
Expand Down Expand Up @@ -369,25 +366,21 @@ let rec binop ty (op : binop) (hte1 : t) (hte2 : t) : t =
| List_append, List l0, List l1 -> make (List (l0 @ l1))
| _ -> binop' ty op hte1 hte2

let triop' (ty : Ty.t) (op : triop) (e1 : t) (e2 : t) (e3 : t) : t =
make (Triop (ty, op, e1, e2, e3))
[@@inline]
let triop' ty op e1 e2 e3 = make (Triop (ty, op, e1, e2, e3)) [@@inline]

let triop ty (op : triop) (e1 : t) (e2 : t) (e3 : t) : t =
let triop ty op e1 e2 e3 =
match (op, view e1, view e2, view e3) with
| Ite, Val True, _, _ -> e2
| Ty.Ite, Val True, _, _ -> e2
| Ite, Val False, _, _ -> e3
| op, Val v1, Val v2, Val v3 -> value (Eval.triop ty op v1 v2 v3)
| _ -> triop' ty op e1 e2 e3

let relop' (ty : Ty.t) (op : relop) (hte1 : t) (hte2 : t) : t =
make (Relop (ty, op, hte1, hte2))
[@@inline]
let relop' ty op hte1 hte2 = make (Relop (ty, op, hte1, hte2)) [@@inline]

let rec relop ty (op : relop) (hte1 : t) (hte2 : t) : t =
let rec relop ty op hte1 hte2 =
match (op, view hte1, view hte2) with
| op, Val v1, Val v2 -> value (if Eval.relop ty op v1 v2 then True else False)
| Ne, Val (Real v), _ | Ne, _, Val (Real v) ->
| Ty.Ne, Val (Real v), _ | Ne, _, Val (Real v) ->
if Float.is_nan v || Float.is_infinite v then value True
else relop' ty op hte1 hte2
| _, Val (Real v), _ | _, _, Val (Real v) ->
Expand Down Expand Up @@ -441,21 +434,18 @@ and relop_list op l1 l2 =
| Ne, _, _ -> unop Ty_bool Not @@ relop_list Eq l1 l2
| (Lt | LtU | Gt | GtU | Le | LeU | Ge | GeU), _, _ -> assert false

let cvtop' (ty : Ty.t) (op : cvtop) (hte : t) : t = make (Cvtop (ty, op, hte))
[@@inline]
let cvtop' ty op hte = make (Cvtop (ty, op, hte)) [@@inline]

let cvtop ty (op : cvtop) (hte : t) : t =
let cvtop ty op hte =
match (op, view hte) with
| String_to_re, _ -> cvtop' ty op hte
| Ty.String_to_re, _ -> cvtop' ty op hte
| _, Val v -> value (Eval.cvtop ty op v)
| String_to_float, Cvtop (Ty_real, ToString, real) -> real
| _ -> cvtop' ty op hte

let naryop' (ty : Ty.t) (op : naryop) (es : t list) : t =
make (Naryop (ty, op, es))
[@@inline]
let naryop' ty op es = make (Naryop (ty, op, es)) [@@inline]

let naryop (ty : Ty.t) (op : naryop) (es : t list) : t =
let naryop ty op es =
if List.for_all (fun e -> match view e with Val _ -> true | _ -> false) es
then
let vs =
Expand Down Expand Up @@ -487,7 +477,7 @@ let extract (hte : t) ~(high : int) ~(low : int) : t =
| Val (Num (I64 x)) ->
let x' = nland64 (Int64.shift_right x (low * 8)) (high - low) in
value (Num (I64 x'))
| _ -> if high - low = size (ty hte) then hte else extract' hte ~high ~low
| _ -> if high - low = Ty.size (ty hte) then hte else extract' hte ~high ~low

let concat' (msb : t) (lsb : t) : t = make (Concat (msb, lsb)) [@@inline]

Expand Down Expand Up @@ -569,6 +559,8 @@ let simplify (hte : t) : t =
loop hte

module Bool = struct
open Ty

let of_val = function
| Val True -> Some true
| Val False -> Some false
Expand All @@ -582,7 +574,7 @@ module Bool = struct

let v b = to_val b [@@inline]

let not (b : t) =
let not b =
let bexpr = view b in
match of_val bexpr with
| Some b -> to_val (not b)
Expand All @@ -591,31 +583,31 @@ module Bool = struct
| Unop (Ty_bool, Not, cond) -> cond
| _ -> unop Ty_bool Not b )

let equal (b1 : t) (b2 : t) =
let equal b1 b2 =
match (view b1, view b2) with
| Val True, Val True | Val False, Val False -> true_
| _ -> relop Ty_bool Eq b1 b2

let distinct (b1 : t) (b2 : t) =
let distinct b1 b2 =
match (view b1, view b2) with
| Val True, Val False | Val False, Val True -> true_
| _ -> relop Ty_bool Ne b1 b2

let and_ (b1 : t) (b2 : t) =
let and_ b1 b2 =
match (of_val (view b1), of_val (view b2)) with
| Some true, _ -> b2
| _, Some true -> b1
| Some false, _ | _, Some false -> false_
| _ -> binop Ty_bool And b1 b2

let or_ (b1 : t) (b2 : t) =
let or_ b1 b2 =
match (of_val (view b1), of_val (view b2)) with
| Some false, _ -> b2
| _, Some false -> b1
| Some true, _ | _, Some true -> true_
| _ -> binop Ty_bool Or b1 b2

let ite (c : t) (r1 : t) (r2 : t) = triop Ty_bool Ite c r1 r2
let ite c r1 r2 = triop Ty_bool Ite c r1 r2
end

module Make (T : sig
Expand All @@ -626,6 +618,8 @@ module Make (T : sig
val num : elt -> Num.t
end) =
struct
open Ty

let v i = value (Num (T.num i))

let sym x = symbol Symbol.(x @: T.ty)
Expand All @@ -646,6 +640,8 @@ struct
end

module Bitv = struct
open Ty

module I8 = Make (struct
type elt = int

Expand All @@ -672,6 +668,8 @@ module Bitv = struct
end

module Fpa = struct
open Ty

module F32 = struct
include Make (struct
type elt = float
Expand Down
10 changes: 5 additions & 5 deletions src/ast/expr.mli
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ and expr =
| Concat of t * t
| Binder of binder * t list * t

val equal : t -> t -> bool
val make : expr -> t

val hash : t -> int
val view : t -> expr

val make : expr -> t
val hash : t -> int

val ( @: ) : expr -> Ty.t -> t [@@deprecated "Please use 'make' instead"]
val equal : t -> t -> bool

val view : t -> expr
val compare : t -> t -> int

(** The return type of an expression *)
val ty : t -> Ty.t
Expand Down

0 comments on commit 429e52f

Please sign in to comment.