Skip to content

optimize context readback #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 126 additions & 28 deletions src/coq_elpi_HOAS.ml
Original file line number Diff line number Diff line change
@@ -1114,6 +1114,8 @@ let mk_coq_context ~options state =
options;
}

let update_options ~options c = { c with options }

let push_coq_ctx_proof i e coq_ctx =
assert(coq_ctx.local = []);
let id = Context.Named.Declaration.get_id e in
@@ -1527,30 +1529,38 @@ let find_evar var csts =
| _ -> None end
| _ -> None)

let preprocess_context visible context =
type processed_context_item =
| Decl of API.Data.term * API.Data.term
| Def of API.Data.term * API.Data.term * API.Data.term

type processed_context = (int * processed_context_item) Int.Map.t

let preprocess_context visible ?(init=Int.Map.empty) context =
let select_ctx_entries visible { E.hdepth = depth; E.hsrc = t } =
let isVisibleConst t = match E.look ~depth t with E.Const i -> visible i | _ -> false in
let destConst t = match E.look ~depth t with E.Const x -> x | _ -> assert false in
match E.look ~depth t with
| E.App(c,v,[name;ty]) when c == declc && isVisibleConst v ->
Some (destConst v, depth, `Decl(name,ty))
Some (destConst v, depth, Decl(name,ty))
| E.App(c,v,[name;ty;bo]) when c == defc && isVisibleConst v ->
Some (destConst v, depth, `Def (name,ty,bo))
Some (destConst v, depth, Def (name,ty,bo))
| _ ->
debug Pp.(fun () ->
str "skip entry" ++
str(pp2string (P.term depth) t));
None
in
let ctx_hyps = CList.map_filter (select_ctx_entries visible) context in
let min_k = try 1 + (fst @@ Int.Map.max_binding init) with Not_found -> 0 in
let min_k = ref min_k in
let dbl2ctx =
List.fold_right (fun (i,d,e) m ->
if Int.Map.mem i m
then err Pp.(str "Duplicate context entry for " ++
str(pp2string (P.term d) (E.mkConst i)))
else Int.Map.add i (d,e) m)
ctx_hyps Int.Map.empty in
dbl2ctx
else (min_k := min !min_k i; Int.Map.add i (d,e) m))
ctx_hyps init in
!min_k, dbl2ctx

let rec dblset_of_canonical_ctx ~depth acc = function
| [] -> acc
@@ -1571,7 +1581,7 @@ let find_evar_decl var csts =
str(pp2string pp_cst cst));
let args = if F.Elpi.(equal raw var) then args_raw else args in
let visible_set = dblset_of_canonical_ctx ~depth Int.Set.empty args in
let dbl2ctx = preprocess_context (fun x -> Int.Set.mem x visible_set) context in
let _,dbl2ctx = preprocess_context (fun x -> Int.Set.mem x visible_set) context in
Some (dbl2ctx, raw, r, (depth,ty), cst)
| _ -> None end
| _ -> None)
@@ -1644,25 +1654,26 @@ let in_coq_poly_gref ~depth ~origin ~failsafe s t i =
str "The term " ++ str (pp2string (P.term depth) origin) ++
str " cannot be represented in Coq since its gref or univ-instance part is illformed"

let rec of_elpi_ctx ~calldepth syntactic_constraints depth dbl2ctx state initial_coq_ctx =
let rec of_elpi_ctx ~calldepth syntactic_constraints depth dbl2ctx state ?(initial_depth=0) initial_coq_ctx =

let aux coq_ctx depth state t =
lp2constr ~calldepth syntactic_constraints coq_ctx ~depth state t in

let of_elpi_ctx_entry dbl coq_ctx ~depth e state =
match e with
| `Decl(name,ty) ->
| Decl(name,ty) ->
let id = in_coq_fresh_annot_id ~depth ~coq_ctx dbl name in
let state, ty, gls = aux coq_ctx depth state ty in
state, Context.Named.Declaration.LocalAssum(id,ty), gls
| `Def(name,ty,bo) ->
| Def(name,ty,bo) ->
let id = in_coq_fresh_annot_id ~depth ~coq_ctx dbl name in
let state, ty, gl1 = aux coq_ctx depth state ty in
let state, bo, gl2 = aux coq_ctx depth state bo in
state, Context.Named.Declaration.LocalDef(id,bo,ty), gl1 @ gl2
in

let rec ctx_entries coq_ctx state gls i =
(*Printf.eprintf "init %d curr %d\n%!" initial_depth i;*)
if i = depth then state, coq_ctx, List.(concat (rev gls))
else (* context entry for the i-th variable *)
if not (Int.Map.mem i dbl2ctx)
@@ -1675,7 +1686,7 @@ let rec of_elpi_ctx ~calldepth syntactic_constraints depth dbl2ctx state initial
let coq_ctx = push_coq_ctx_proof i e coq_ctx in
ctx_entries coq_ctx state (gl1 :: gls) (i+1)
in
ctx_entries initial_coq_ctx state [] 0
ctx_entries initial_coq_ctx state [] initial_depth

(* ***************************************************************** *)
(* <-- depth --> *)
@@ -1747,7 +1758,7 @@ and lp2constr ~calldepth syntactic_constraints coq_ctx ~depth state ?(on_ty=fals
state, hole, []
else
err Pp.(hov 0 (str"Bound variable " ++ str (E.Constants.show n) ++
str" not found in the Coq context:" ++ cut () ++
str" not found at depth " ++ int depth ++ str " in the Coq context:" ++ cut () ++
pr_coq_ctx coq_ctx (get_sigma state) ++ cut () ++
str"Did you forget to load some hypotheses with => ?"))
else
@@ -3293,25 +3304,112 @@ module CtxReadbackCache = Ephemeron.K1.Make(struct
let equal = (==)
let hash = Hashtbl.hash
end)
let ctx_cache_lp2c = CtxReadbackCache.create 1

type ctx_cache_lp2c = {
hashtbl : (processed_context * full coq_context * Environ.env) CtxReadbackCache.t;
last_depth : int ref;
}
let ctx_cache_lp2c = { hashtbl = CtxReadbackCache.create 1 ; last_depth = ref 0 }

let ctx_cache_lp2c_set depth hyps processed_ctx coq_ctx =
(*Printf.eprintf "set cache at %d:\n%!" depth;
let pp (i,j) = Printf.eprintf "%d -> %s\n%!" i (Names.Id.to_string j) in
List.iter pp (Int.Map.bindings coq_ctx.db2name);*)
let { hashtbl; last_depth } = ctx_cache_lp2c in
CtxReadbackCache.reset hashtbl;
CtxReadbackCache.add hashtbl hyps (processed_ctx,coq_ctx,Global.env ());
last_depth := depth

let ctx_cache_lp2c_reset () =
(*Printf.eprintf "reset cache\n%!";*)
let { hashtbl; last_depth } = ctx_cache_lp2c in
CtxReadbackCache.reset hashtbl;
last_depth := 0

type cache_search_result =
| NotFound
| OldEntry of int * API.Data.hyp list * processed_context * full coq_context
| Found of processed_context * full coq_context

let ctx_cache_lp2c_find hyps =
match CtxReadbackCache.find ctx_cache_lp2c.hashtbl hyps with
| (pc,c,e) when e == Global.env () -> Found(pc,c)
| _ -> NotFound
| exception Not_found -> NotFound

let rec old_cache depth new_hyps_rev hyps =
match hyps with
| [] -> NotFound
| h :: hyps ->
match ctx_cache_lp2c_find hyps with
| Found(pc,c) ->
OldEntry (depth,List.rev (h::new_hyps_rev),pc,c)
| _ ->
match E.of_hyps [h] with
| [{ E.hdepth }] when hdepth >= depth ->
old_cache depth (h :: new_hyps_rev) hyps
| _ -> NotFound

let ctx_cache_lp2c_search depth hyps =
let last_depth = !(ctx_cache_lp2c.last_depth) in
if depth > last_depth then
old_cache last_depth [] hyps
else
ctx_cache_lp2c_find hyps

let get_current_env_sigma ~depth hyps constraints state =
(*Printf.eprintf "--------------------\nget_current_env_sigma called at depth %d\n%!" depth;*)
let state, _, changed, gl1 = elpi_solution_to_coq_solution constraints state in
if changed then CtxReadbackCache.reset ctx_cache_lp2c;
let state, coq_ctx, gl2 =
match CtxReadbackCache.find ctx_cache_lp2c hyps with
| (c,e,d) when d == depth && e == Global.env () -> state, c, []
| _ ->
of_elpi_ctx ~calldepth:depth constraints depth
(preprocess_context (fun _ -> true) (E.of_hyps hyps))
state (mk_coq_context ~options:(get_options ~depth hyps state) state)
| exception Not_found ->
of_elpi_ctx ~calldepth:depth constraints depth
(preprocess_context (fun _ -> true) (E.of_hyps hyps))
state (mk_coq_context ~options:(get_options ~depth hyps state) state)
if changed then ctx_cache_lp2c_reset ();
let options = get_options ~depth hyps state in
let processed_ctx, (state, coq_ctx, gl2) =
match ctx_cache_lp2c_search depth hyps with
| Found(pc,c) -> pc, (state, update_options ~options c, [])
| OldEntry (old_depth,todo,old_pc,c) ->
(*let _,pc' = preprocess_context (fun _ -> true) (E.of_hyps hyps) in*)
let min,pc = preprocess_context (fun _ -> true) ~init:old_pc (E.of_hyps todo) in
(*
if(not(Int.Map.equal (=) pc pc')) then begin
let m0 = Int.Map.bindings old_pc in
let m1 = Int.Map.bindings pc in
let m2 = Int.Map.bindings pc' in
let pp (i,(j,_)) = Printf.eprintf "%d -> %d\n%!" i j in
Printf.eprintf "old:\n%!";
List.iter pp m0;
Printf.eprintf "wrong (%d):\n%!" (List.length todo);
List.iter pp m1;
Printf.eprintf "correct:\n%!";
List.iter pp m2;
end;
Printf.eprintf "adapt cache entry from depth to: %d -> %d\n%!" min depth;
Printf.eprintf "OLD:\n%s\nADD:\n%!" (Pp.string_of_ppcmds (pr_coq_ctx c (get_sigma state)));
let m = Int.Map.bindings pc in
let pp (i,(j,_)) = Printf.eprintf "%d -> %d\n%!" i j in
List.iter pp m;*)

let state,ctx,gls = of_elpi_ctx ~calldepth:depth constraints depth pc state ~initial_depth:(min) c in
(*Printf.eprintf "ignore:\n%!";
let _,ctx',_ = of_elpi_ctx ~calldepth:depth constraints depth pc' state (mk_coq_context ~options state) in
Printf.eprintf "end:\n%!";
*)
(* if not(ctx = ctx') then begin *)
(*Printf.eprintf "UPDATED:\n%s\n%!" (Pp.string_of_ppcmds (pr_coq_ctx ctx (get_sigma state)));*)
(* Printf.eprintf "correct: %d \n%!" ctx'.proof_len;
Printf.eprintf "%s\n%!" (Pp.string_of_ppcmds (pr_coq_ctx ctx' (get_sigma state)));
end;
*)
pc, (state, ctx, gls)
| NotFound ->
let _,pc = preprocess_context (fun _ -> true) (E.of_hyps hyps) in
(*
let m2 = Int.Map.bindings pc in
let pp (i,(j,_)) = Printf.eprintf "%d -> %d\n%!" i j in
Printf.eprintf "new:\n%!";
List.iter pp m2;*)
pc,
of_elpi_ctx ~calldepth:depth constraints depth pc state (mk_coq_context ~options state)
in
CtxReadbackCache.reset ctx_cache_lp2c;
CtxReadbackCache.add ctx_cache_lp2c hyps (coq_ctx,Global.env (),depth);
ctx_cache_lp2c_set depth hyps processed_ctx coq_ctx;
state, coq_ctx, get_sigma state, gl1 @ gl2
;;

@@ -3331,7 +3429,7 @@ let lp2goal ~depth hyps syntactic_constraints state t =
let visible_set = dblset_of_canonical_ctx ~depth Int.Set.empty scope in
let state, coq_ctx, gl2 =
of_elpi_ctx ~calldepth:depth syntactic_constraints depth
(preprocess_context (fun x -> Int.Set.mem x visible_set)
(snd @@ preprocess_context (fun x -> Int.Set.mem x visible_set)
(U.lp_list_to_list ~depth ctx |> List.map (fun hsrc -> { E.hdepth = depth; E.hsrc })))
state
(mk_coq_context ~options:(get_options ~depth hyps state) state) in
41 changes: 36 additions & 5 deletions tests/test_ctx_cache.v
Original file line number Diff line number Diff line change
@@ -14,11 +14,25 @@ solve (goal _ _ _ _ [str "cache", int N]) _ :- !,
loop N (coq.unify-eq {{ 0 + 0 }} {{ 0 }} ok, @pi-decl `x` {{ bool }} x\ true).
solve (goal _ _ _ _ [str "nocache", int N]) _ :- !,
loop N (@pi-decl `x` {{ bool }} x\ coq.unify-eq {{ 0 + 0 }} {{ 0 }} ok, true).
solve (goal _ _ _ _ [str "deepcache", int N]) _ :- !,
dloop N (coq.unify-eq {{ 0 + 0 }} {{ 0 }} ok, true).
solve (goal _ _ _ _ [str "nodeepcache", int N]) _ :- !,
dloop N (@pi-decl `x` {{ bool }} x\coq.unify-eq {{ 0 + 0 }} {{ 0 }} ok, true).

pred dloop i:int, i:prop.
dloop 0 _.
dloop M P :-
N is M - 1,
P,
@pi-decl `x` {{ bool }} x\ dloop N P.

}}.
Elpi Typecheck.

Elpi Command say.
Elpi Accumulate lp:{{ main [str X] :- coq.say X. }}.
Elpi Export say.

Notation t :=
(
nat * nat * nat * nat * nat * nat * nat * nat * nat * nat * nat * nat *
@@ -30,6 +44,7 @@ Notation t :=
nat * nat * nat * nat * nat * nat * nat * nat * nat * nat * nat * nat *
nat * nat * nat * nat * nat * nat * nat * nat * nat * nat * nat * nat
)%type.
say "Huge context stup".
Time Goal
forall (x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 : t),
forall (x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 : t),
@@ -38,10 +53,26 @@ forall (x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 : t),
forall (x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 : t),
True.
intros.
Optimize Heap.
Time elpi perf nocache 3000.
Optimize Heap.
Time elpi perf cache 3000.

say "----------------------".
say "## bench cache (same ctx)".
say "----------------------".
say "no cache, 3K".
Optimize Heap. Time elpi perf nocache 3000.
say " cache, 3K".
Optimize Heap. Time elpi perf cache 3000.
say "----------------------".
say "## bench incremental cache (growing ctx)".
say "----------------------".
say "no cache, 1K".
Optimize Heap. Time elpi perf nodeepcache 1000.
say "no cache, 2K".
Optimize Heap. Time elpi perf nodeepcache 2000.
say " cache, 3K".
Optimize Heap. Time elpi perf deepcache 3000.
say " cache, 6K".
Optimize Heap. Time elpi perf deepcache 6000.
say " cache, 9K".
Optimize Heap. Time elpi perf deepcache 9000.
say "----------------------".
trivial.
Qed.