-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathl2_opt.ML
277 lines (254 loc) · 8.6 KB
/
l2_opt.ML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Optimise L2 fragments of code by using facts learnt earlier in the fragments
* to simplify code afterwards.
*)
structure L2Opt =
struct

(*
 * Map the given simpset to tweak it for L2Opt.
 *
 * Returns a "Proof.context -> Proof.context" transformer that adds the
 * "if_cong", "split_cong" and "HOL.conj_cong" congruence rules.
 *
 * If "use_ugly_rules" is enabled, we additionally add "split_def": useful for
 * discharging proofs, but it makes the output ugly.
 *)
fun map_opt_simpset use_ugly_rules =
  Simplifier.add_cong @{thm if_cong}
  #> Simplifier.add_cong @{thm split_cong}
  #> Simplifier.add_cong @{thm HOL.conj_cong}
  #> (if use_ugly_rules then
        (fn ctxt => ctxt addsimps [@{thm split_def}])
      else
        I)

(*
 * Solve a goal of the form:
 *
 *   simp_expr P A ?X
 *
 * This is done by simplifying "A" while assuming "P", and unifying the result
 * (usually instantiating "X") in the process.
 *)
val simp_expr_thm =
  @{lemma "(simp_expr P G G == simp_expr P G G') ==> simp_expr P G G'" by (clarsimp simp: simp_expr_def)}

fun solve_simp_expr_tac ctxt =
  (* The focused context shadows the outer "ctxt" inside the tactic body. *)
  Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
    (fn thm =>
      case Drule.cprems_of thm of
        [] => (no_tac thm)
      | (goal :: _) =>
          (case term_of goal of
            (_ $ (Const (@{const_name "simp_expr"}, _) $ P $ L $ _)) =>
              let
                (* Build "simp_expr P L L" and rewrite it under the assumption P. *)
                val goal = @{mk_term "simp_expr ?P ?L ?L" (P, L)} (P, L)
                  |> cterm_of (Proof_Context.theory_of ctxt)
                val simplified = Simplifier.asm_full_rewrite (map_opt_simpset false ctxt) goal
                (* If the rewrite still mentions schematics, we cannot use it to
                 * instantiate the goal; fall back to the trivial "simp_expr" rule. *)
                val schematic_remains = Term.exists_subterm Term.is_Var (prop_of simplified)
              in
                if schematic_remains then
                  (rtac @{thm simp_expr_triv} 1) thm
                else
                  ((rtac simp_expr_thm 1) THEN (rtac simplified 1)) thm
              end
          | _ => no_tac thm)
    )) ctxt

(*
 * Solve a goal of the form:
 *
 *   simp_expr P A B
 *
 * where both "A" and "B" are constants (i.e., not schematics). Fails (no_tac)
 * if the first subgoal contains any schematic variable, or if the goal is not
 * completely solved by "simp_expr_solve_constant" followed by clarsimp.
 *)
fun solve_simp_expr_const_tac ctxt thm =
  if (Term.exists_subterm Term.is_Var (term_of (cprem_of thm 1))) then
    no_tac thm
  else
    SOLVES (
      (rtac @{thm simp_expr_solve_constant} 1)
      THEN (Clasimp.clarsimp_tac (map_opt_simpset true ctxt) 1)) thm

(*
 * Given a goal of the form:
 *
 *   monad_equiv P L R Q E
 *
 * simplify "P" and weaken the goal's precondition to the simplified form
 * (via "monad_equiv_weaken_pre''"). Raises CTERM if that rule cannot be
 * applied. Goals of any other shape are left unchanged (all_tac).
 *
 * The idea here is to avoid exponential blow-up by keeping the precondition
 * small as it is propagated through the program.
 *)
fun simp_monad_equiv_pre_tac ctxt =
  Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
    (fn thm =>
      case term_of (cprem_of thm 1) of
        Const (@{const_name Trueprop}, _) $
          (Const (@{const_name monad_equiv}, _) $ P $ _ $ _ $ _ $ _) =>
          let
            val thy = Proof_Context.theory_of ctxt
            (* Perform basic simplification of the term. *)
            val simp_thm = Simplifier.asm_full_rewrite (map_opt_simpset false ctxt) (cterm_of thy P)
          in
            (rtac (@{thm monad_equiv_weaken_pre''} OF [simp_thm]) 1
              ORELSE (fn t => raise (CTERM ("failed to prove goal", [cprem_of t 1])))) thm
          end
      | _ =>
          all_tac thm
    )) ctxt

(*
 * Recursively simplify a monadic expression, using information gleaned from
 * earlier in the program to simplify parts of the program further down.
 *
 * Given a cterm "L", returns a theorem "L == R" for the simplified "R".
 *)
fun monad_equiv ctxt ct =
  let
    (* Mark context as being "invisible" to reduce warnings being printed. *)
    val ctxt = Context_Position.set_visible false ctxt
    val thy = Proof_Context.theory_of ctxt

    (* Generate our top-level "monad_equiv" goal. *)
    val goal = @{mk_term "?L == ?R" (L)} (term_of ct)
      |> cterm_of thy
      |> Goal.init
      |> Utils.apply_tac "Creating object-level equality." (rtac @{thm eq_reflection} 1)
      |> Utils.apply_tac "Creating 'monad_equiv' goal." (rtac @{thm monad_equiv_eq} 1)

    (* Print a diagnostic if this branch fails. Currently disabled by the
     * "false andalso" guard; replace "false" with "true" while debugging to
     * print at most the first five failures. *)
    val num_failures = ref 0
    fun print_failure_tac t =
      if (false andalso !num_failures < 5) then
        (num_failures := !num_failures + 1; (print_tac "Branch failed" THEN no_tac) t)
      else
        (no_tac t)

    (* Fetch theorems used in the simplification process. *)
    val thms = L2FlowThms.get ctxt

    (*
     * Tactic to blindly apply simplification rules.
     *
     * NOTE(review): the subgoal index argument is ignored and every step works
     * on subgoal 1; presumably safe because each attempt is wrapped in SOLVES,
     * so subgoals are discharged one at a time from the front -- confirm.
     *)
    fun solve_goal_tac _ =
      (simp_monad_equiv_pre_tac ctxt 1)
      THEN DETERM (
        SOLVES
          ((solve_simp_expr_const_tac ctxt)
           ORELSE
           ((solve_simp_expr_tac ctxt 1)
            ORELSE
            ((resolve_tac thms THEN_ALL_NEW solve_goal_tac) 1
             ORELSE
             ((print_failure_tac))))))

    (* Apply the rules. *)
    val thm =
      Utils.apply_tac "Simplifying L2" (solve_goal_tac 1) goal
      |> Goal.finish ctxt
  in
    thm
  end

(*
 * A simproc implementing the "L2_gets_bind" rule. The rule, unfortunately, has
 * the ability to cause exponential growth in the spec size in some cases;
 * thus, we can only selectively apply it in cases where this doesn't happen.
 *
 * In particular, we propagate a "gets" into its usage if it is used at most
 * once, or if the value being fetched is a single variable or constant (which
 * is cheap to duplicate).
 *)
val l2_gets_bind_thm = mk_meta_eq @{thm L2_gets_bind}

fun l2_gets_bind_simproc' ctxt term =
  let
    (* Terms that are cheap to duplicate. *)
    fun is_atomic (Bound _) = true
      | is_atomic (Free _) = true
      | is_atomic (Const _) = true
      | is_atomic _ = false

    (* "lhs" has the shape "L2_gets (%_. A) n"; inspect the body "A" of the
     * gets function, not the trailing name argument "n". *)
    fun is_simple (Const (@{const_name "L2_gets"}, _) $ Abs (_, _, body) $ _) = is_atomic body
      | is_simple _ = false

    (* Count occurrences of the variable bound by the abstraction whose body
     * is the argument, tracking de Bruijn depth "lev" under nested binders. *)
    fun count_var_usage lev (a $ b) = count_var_usage lev a + count_var_usage lev b
      | count_var_usage lev (Abs (_, _, x)) = count_var_usage (lev + 1) x
      | count_var_usage lev (Bound i) = if i = lev then 1 else 0
      | count_var_usage _ _ = 0
  in
    case term of
      (Const (@{const_name "L2_seq"}, _) $ lhs $ Abs (_, _, rhs)) =>
        let
          val count = count_var_usage 0 rhs
        in
          if count <= 1 orelse is_simple lhs then
            SOME l2_gets_bind_thm
          else
            NONE
        end
    | _ => NONE
  end

val l2_gets_bind_simproc =
  Raw_Simplifier.mk_simproc "L2_gets_bind_simproc"
    [@{cpat "L2_seq (L2_gets (%_. ?A) ?n) ?B"}] l2_gets_bind_simproc'

(*
 * Simproc to clean up guards: rewrites the guard using "ss" (plus conj_cong),
 * returning NONE when the rewrite leaves the term unchanged so that the
 * simproc only fires when it makes progress.
 *)
fun l2_guard_simproc' ss ctxt term =
  let
    val thy = Proof_Context.theory_of ctxt
    val simp_thm = Simplifier.asm_full_rewrite
      (Simplifier.add_cong @{thm HOL.conj_cong} (put_simpset ss ctxt)) (cterm_of thy term)
    val (lhs, rhs) = Thm.prop_of (Drule.eta_contraction_rule simp_thm) |> Logic.dest_equals
  in
    if Term_Ord.fast_term_ord (lhs, rhs) = EQUAL then
      NONE
    else
      SOME simp_thm
  end

fun l2_guard_simproc ss =
  Raw_Simplifier.mk_simproc "L2_guard_simproc"
    [@{cpat "L2_guard ?G"}] (l2_guard_simproc' ss)

(*
 * Adjust "prod_case" commands so that constructs such as:
 *
 *   while C (%x. gets (case x of (a, b) => %s. P a b)) ...
 *
 * are transformed into:
 *
 *   while C (%(a, b). gets (%s. P a b)) ...
 *
 * Uses HOL_ss together with the "L2_split_fixups" rules and congruences.
 *)
fun fix_L2_while_loop_splits_conv ctxt =
  Simplifier.asm_full_rewrite (
    put_simpset HOL_ss ctxt
      addsimps @{thms L2_split_fixups}
    |> fold Simplifier.add_cong @{thms L2_split_fixups_congs})

(*
 * Carry out flow-sensitive optimisations on the given 'thm'.
 *
 * 'n' is the argument number to cleanup, counting from 1. So for example, if
 * our input theorem was "corres P A B", an "n" of 3 would simplify "A".
 *
 * If "fast_mode" is true, we don't do flow-sensitive optimisations (which tend
 * to be time-consuming); only the peephole simplifications and beta/eta
 * normalisation are performed.
 *)
fun cleanup_thm ctxt thm fast_mode n =
  let
    (* Don't print out warning messages. *)
    val ctxt = Context_Position.set_visible false ctxt

    (* Setup basic simplifier. *)
    fun basic_ss ctxt =
      put_simpset (simpset_of @{theory_context AutoCorresSimpset}) ctxt
      |> (fn ctxt => ctxt addsimps (L2PeepholeThms.get ctxt))
      |> (fn ctxt => ctxt addsimprocs [l2_gets_bind_simproc, l2_guard_simproc (simpset_of ctxt)])
      |> map_opt_simpset false

    fun simp_conv ctxt =
      Drule.beta_eta_conversion
      then_conv (fix_L2_while_loop_splits_conv ctxt)
      then_conv (Simplifier.rewrite (basic_ss ctxt))

    (* Lift a context-dependent conversion onto the n-th argument of the
     * theorem's conclusion. *)
    fun l2conv conv =
      Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv n (conv ctxt)) ctxt

    (* Apply peephole optimisations to the theorem. *)
    val new_thm =
      Conv.fconv_rule (l2conv simp_conv) thm
      |> Drule.eta_contraction_rule

    (* Apply flow-sensitive optimisations, and then re-apply simple simplifications. *)
    val new_thm =
      if not fast_mode then
        Conv.fconv_rule (
          l2conv (fn ctxt =>
            monad_equiv ctxt
            then_conv (simp_conv ctxt)
          )) new_thm
      else
        new_thm

    (* Beta/Eta normalise. *)
    val new_thm = Conv.fconv_rule (l2conv (K Drule.beta_eta_conversion)) new_thm
  in
    new_thm
  end

end