Skip to content

Commit 47ec4ac

Browse files
committed
Start caching errors in Memo
Signed-off-by: Andrey Mokhov <amokhov@janestreet.com>
1 parent 83707b4 commit 47ec4ac

File tree

2 files changed

+73
-28
lines changed

2 files changed

+73
-28
lines changed

src/memo/memo.ml

+31-17
Original file line number | Diff line number | Diff line change
@@ -718,15 +718,19 @@ module Cached_value = struct
718718
t.deps <- capture_dep_values ~deps_rev;
719719
t
720720

721-
let value_changed (type o) (node : (_, o) Dep_node.t) prev_output curr_output
722-
=
723-
match (prev_output, curr_output) with
724-
| (Value.Error _ | Cancelled _), _ -> true
725-
| _, (Value.Error _ | Cancelled _) -> true
726-
| Ok prev_output, Ok curr_output -> (
721+
let value_changed (node : _ Dep_node.t) prev_value cur_value =
722+
match ((prev_value : _ Value.t), (cur_value : _ Value.t)) with
723+
| Cancelled _, _
724+
| _, Cancelled _
725+
| Error _, Ok _
726+
| Ok _, Error _ ->
727+
true
728+
| Ok prev_value, Ok cur_value -> (
727729
match node.without_state.spec.allow_cutoff with
728-
| Yes equal -> not (equal prev_output curr_output)
730+
| Yes equal -> not (equal prev_value cur_value)
729731
| No -> true)
732+
| Error prev_error, Error cur_error ->
733+
not (Exn_set.equal prev_error cur_error)
730734
end
731735

732736
(* Add a dependency on the [dep_node] from the caller, if there is one. Returns
@@ -832,7 +836,7 @@ let dep_node (t : (_, _) t) input =
832836
833837
- [Unchanged]: all the dependencies of the current node are up to date and we
834838
can therefore skip recomputing the node and can reuse the value computed in
835-
the previuos run.
839+
the previous run.
836840
837841
- [Changed]: one of the dependencies has changed since the previous run and
838842
the current node should therefore be recomputed.
@@ -873,13 +877,15 @@ end = struct
873877
(* Dependencies of cancelled computations are not accurate, so we can't
874878
use [deps_changed] in this case. *)
875879
Fiber.return (Error Cache_lookup.Failure.Not_found)
876-
| Error _ ->
877-
(* We always recompute errors, so there is no point in checking if any
878-
of their dependencies changed. In principle, we could introduce
879-
"persistent errors" that are recomputed only when their dependencies
880-
have changed. *)
881-
Fiber.return (Error Cache_lookup.Failure.Not_found)
882-
| Ok _ -> (
880+
| Ok _
881+
| Error _ -> (
882+
(* We cache errors just like normal values. We assume that all [Memo]
883+
computations are deterministic, which means if we rerun a computation
884+
that previously led to raising a set of errors, we expect to get the
885+
same set of errors back and we might as well skip the unnecessary
886+
work. The downside is that if a computation is non-deterministic,
887+
there is no way to force rerunning it, apart from changing some of
888+
its dependencies. *)
883889
let+ deps_changed =
884890
let rec go deps =
885891
match deps with
@@ -891,8 +897,16 @@ end = struct
891897
is up to date. If not, we must recompute [last_cached_value]. *)
892898
let* restore_result = consider_and_restore_from_cache dep in
893899
match restore_result with
894-
| Ok cached_value -> (
895-
match Value_id.equal cached_value.id v_id with
900+
| Ok cached_value_of_dep -> (
901+
(* Here we know that [dep] can be restored from the cache, so
902+
how can [v_id] be different from [cached_value_of_dep.id]?
903+
Good question! This can happen if [cached_value]'s node was
904+
skipped in the previous run (because it was unreachable),
905+
while [dep] wasn't skipped and its value changed. In the
906+
current run, [cached_value] is therefore stale. We learn
907+
this when we see that the [cached_value_of_dep] is not as
908+
recorded when computing [cached_value]. *)
909+
match Value_id.equal cached_value_of_dep.id v_id with
896910
| true -> go deps
897911
| false -> Fiber.return Changed_or_not.Changed)
898912
| Error (Cancelled { dependency_cycle }) ->

test/expect-tests/memo/memoize_tests.ml

+42-11
Original file line number | Diff line number | Diff line change
@@ -972,34 +972,34 @@ let%expect_test "dynamic cycles with non-uniform cutoff structure" =
972972
evaluate_and_print summit_no_cutoff 0;
973973
[%expect
974974
{|
975-
Started evaluating the summit with input 0
976-
Started evaluating incrementing_chain_4_yes_cutoff
977-
Started evaluating incrementing_chain_3_no_cutoff
975+
Started evaluating base
976+
Evaluated base: 3
978977
Started evaluating incrementing_chain_2_yes_cutoff
979978
Started evaluating incrementing_chain_1_no_cutoff
980979
Started evaluating cycle_creator_no_cutoff
981-
Started evaluating base
982-
Evaluated base: 3
983980
Evaluated cycle_creator_no_cutoff: 3
984981
Evaluated incrementing_chain_1_no_cutoff: 4
985982
Evaluated incrementing_chain_2_yes_cutoff: 5
983+
Started evaluating incrementing_chain_4_yes_cutoff
984+
Started evaluating incrementing_chain_3_no_cutoff
986985
Evaluated incrementing_chain_3_no_cutoff: 6
987986
Evaluated incrementing_chain_4_yes_cutoff: 7
987+
Started evaluating the summit with input 0
988988
Evaluated the summit with input 0: 7
989989
f 0 = Ok 7 |}];
990990
evaluate_and_print summit_yes_cutoff 0;
991991
[%expect
992992
{|
993-
Started evaluating the summit with input 0
994-
Started evaluating incrementing_chain_4_no_cutoff
995-
Started evaluating incrementing_chain_3_yes_cutoff
996-
Started evaluating incrementing_chain_2_no_cutoff
997-
Started evaluating incrementing_chain_1_yes_cutoff
998993
Started evaluating cycle_creator_yes_cutoff
999994
Evaluated cycle_creator_yes_cutoff: 3
995+
Started evaluating incrementing_chain_1_yes_cutoff
1000996
Evaluated incrementing_chain_1_yes_cutoff: 4
997+
Started evaluating incrementing_chain_3_yes_cutoff
998+
Started evaluating incrementing_chain_2_no_cutoff
1001999
Evaluated incrementing_chain_2_no_cutoff: 5
10021000
Evaluated incrementing_chain_3_yes_cutoff: 6
1001+
Started evaluating the summit with input 0
1002+
Started evaluating incrementing_chain_4_no_cutoff
10031003
Evaluated incrementing_chain_4_no_cutoff: 7
10041004
Evaluated the summit with input 0: 7
10051005
f 0 = Ok 7 |}];
@@ -1403,7 +1403,6 @@ let%expect_test "error handling and duplicate exceptions" =
14031403
in
14041404
Fdecl.set f_impl (fun x ->
14051405
printf "Calling f %d\n" x;
1406-
14071406
match x with
14081407
| 0 -> Memo.exec forward_fail x
14091408
| 1 -> Memo.exec forward_fail2 x
@@ -1420,3 +1419,35 @@ let%expect_test "error handling and duplicate exceptions" =
14201419
Calling f 0
14211420
Error [ "(Failure 42)" ]
14221421
|}]
1422+
1423+
let%expect_test "errors are cached" =
1424+
Printexc.record_backtrace false;
1425+
let f =
1426+
Memo.create_hidden "area of a square"
1427+
~input:(module Int)
1428+
(fun x ->
1429+
printf "Started evaluating %d\n" x;
1430+
if x < 0 then failwith (sprintf "Negative input %d" x);
1431+
let res = x * x in
1432+
printf "Evaluated %d: %d\n" x res;
1433+
Memo.Build.return res)
1434+
in
1435+
evaluate_and_print f 5;
1436+
evaluate_and_print f (-5);
1437+
[%expect
1438+
{|
1439+
Started evaluating 5
1440+
Evaluated 5: 25
1441+
f 5 = Ok 25
1442+
Started evaluating -5
1443+
f -5 = Error [ { exn = "(Failure \"Negative input -5\")"; backtrace = "" } ]
1444+
|}];
1445+
evaluate_and_print f 5;
1446+
evaluate_and_print f (-5);
1447+
(* Note that we do not see any "Started evaluating" messages because both [Ok]
1448+
and [Error] results have been cached. *)
1449+
[%expect
1450+
{|
1451+
f 5 = Ok 25
1452+
f -5 = Error [ { exn = "(Failure \"Negative input -5\")"; backtrace = "" } ]
1453+
|}]

0 commit comments

Comments
 (0)