-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmonad_convert.ML
215 lines (184 loc) · 7.46 KB
/
monad_convert.ML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
(*
* Copyright 2014, NICTA
*
* This software may be distributed and modified according to the terms of
* the BSD 2-Clause license. Note that NO WARRANTY is provided.
* See "LICENSE_BSD2.txt" for details.
*
* @TAG(NICTA_BSD)
*)
(*
* Code to manage converting between L2_monad and other monad types.
*
* TypeStrengthen provides a higher level interface for converting entire programs.
*)
structure Monad_Convert = struct

(* Utilities. *)

(* Insert the separator "a" between consecutive elements of a list. *)
fun intersperse _ [] = []
  | intersperse _ [x] = [x]
  | intersperse a (x::xs) = x :: a :: intersperse a xs

(* Unwrap an option, raising "exc" if it is NONE. *)
fun theE NONE exc = raise exc
  | theE (SOME x) _ = x

(* Return the head of a list, raising "exc" if the list is empty. *)
fun oneE [] exc = raise exc
  | oneE (x::_) _ = x

(* From Find_Theorems *)
(*
 * Strip the leading lambdas of "tm", apply "tm" to dummy patterns of the
 * corresponding bound-variable types, and then replace those dummies using
 * Term.replace_dummy_patterns (starting at index 1).
 *)
fun apply_dummies tm =
  let
    val (xs, _) = Term.strip_abs tm;
    val tm' = Term.betapplys (tm, map (Term.dummy_pattern o #2) xs);
  in #1 (Term.replace_dummy_patterns tm' 1) end;

(*
 * Turn the user-supplied string "nm" into a term pattern.
 *
 * If "nm" resolves to a constant that is an abbreviation, the pattern is the
 * abbreviation's expanded right-hand side applied to dummies.
 * Otherwise "nm" is read directly as a term pattern.
 *)
fun parse_pattern ctxt nm =
  let
    val consts = Proof_Context.consts_of ctxt;
    val nm' =
      (case Syntax.parse_term ctxt nm of
        Const (c, _) => c
      | _ => Consts.intern consts nm);
  in
    (case try (Consts.the_abbreviation consts) nm' of
      SOME (_, rhs) => apply_dummies (Proof_Context.expand_abbrevs ctxt rhs)
    | NONE => Proof_Context.read_term_pattern ctxt nm)
  end;

(*
 * Breadth-first term search.
 *
 * Subterms are visited in breadth-first order. When the search descends
 * under an abstraction, the bound variable is instantiated with a Free
 * variable. The Free's name is primed until it differs from every name
 * already introduced along the current path. "vars" is the list of names
 * introduced so far.
 *
 * For each visited (vars, term):
 *   - if "pred term" holds, "cont (vars, term) k" is called, where "k ()"
 *     resumes the search. "cont" may stop early by not calling "k". The
 *     subterms of a matching term are not visited.
 *   - else, if "prune term" holds, the term's subterms are skipped;
 *   - else, the term's immediate subterms are queued.
 *)
fun term_search_bf cont pred prune = let
  fun fresh_var vars v = if member (op =) vars v then fresh_var vars (v ^ "'") else v
  fun search ((vars, term), queue) =
    if pred term then cont (vars, term) (fn () => walk queue) else
    if prune term then walk queue else
    case term of
      t as Abs (v, typ, _) =>
        let val v' = fresh_var vars v in
          walk (Queue.enqueue
            ((v'::vars), betapply (t, Free (v', typ))) queue)
        end
    | f $ x => walk (Queue.enqueue (vars, x) (Queue.enqueue (vars, f) queue))
    | _ => walk queue
  and walk queue = if Queue.is_empty queue then () else search (Queue.dequeue queue)
in
  (fn term => search (([], term), Queue.empty))
end

(*
 * Return the first (vars, term) found by term_search_bf that satisfies
 * "pred", or NONE if none is found. The continuation is ignored, so the
 * search stops at the first match.
 *)
fun term_search_bf_first pred prune term = let
  val r = Unsynchronized.ref NONE
  val _ = term_search_bf (fn result => K (r := SOME result)) pred prune term
in !r end

(*
 * Build a function that searches a term for the first subterm matching
 * "pattern". Subterms are pruned when Pattern.matches_subterm reports that
 * they contain no match.
 *)
fun grep_term ctxt pattern =
  let
    val thy = Proof_Context.theory_of ctxt
  in
    term_search_bf_first
      (fn term => Pattern.matches thy (pattern, term))
      (fn term => not (Pattern.matches_subterm thy (pattern, term)))
  end

(* Check whether the term is in L2_monad notation. *)
val term_is_L2 = Monad_Types.check_lifting_head
  [@{term "L2_unknown"}, @{term "L2_seq"}, @{term "L2_modify"},
   @{term "L2_gets"}, @{term "L2_condition"}, @{term "L2_catch"}, @{term "L2_while"},
   @{term "L2_throw"}, @{term "L2_spec"}, @{term "L2_guard"}, @{term "L2_fail"},
   @{term "L2_recguard"}, @{term "L2_call"}]

(*
 * Perform monad conversion on a term, taking into account any extra
 * simplifying facts. Only a successful conversion is returned.
 *
 * "forward = true" rewrites with the monad type's lift_rules, and succeeds
 * when the result satisfies #valid_term.
 * "forward = false" rewrites with its unlift_rules, and succeeds when the
 * result is in L2_monad notation (term_is_L2).
 *
 * On success, returns SOME of the rewrite theorem for "term"; otherwise NONE.
 *
 * For this conversion to be useful on recursive programs, it needs
 * to be given a fact representing the inductive assumption.
 *)
fun monad_rewrite (lthy : Proof.context) (mt : Monad_Types.monad_type)
                  (more_facts : thm list) (forward : bool)
                  (term : term) : thm option =
let
  val lthy = Utils.set_hidden_ctxt lthy
  val rules = if forward then #lift_rules mt else #unlift_rules mt
  val rules' = dest_ss rules |> #simps |> map #2
  val cterm = cterm_of (Proof_Context.theory_of lthy) term
  (* Just apply the simplifier and hope that it works. *)
  val thm = Simplifier.rewrite (
              put_simpset HOL_ss lthy addsimps rules' addsimps more_facts) cterm
  val rhs = Utils.rhs_of (term_of_thm thm)
  val good_rewrite = if forward then #valid_term mt else term_is_L2
in
  if good_rewrite lthy rhs
  then SOME thm else NONE
end

(*
 * Apply polish to a theorem of the form:
 *
 *   <LHS> == <lift> $ <some term to polish>
 *
 * The polished part is rewritten with the monad type's #polish_rules
 * simpset plus the PolishSimps rules. "prod_case" eta-redexes are then
 * contracted bottom-up.
 *
 * Return the new theorem.
 *)
local
  val prod_case_eta_contract_thm =
    @{lemma "(%x. (prod_case s) x) == (prod_case s)" by simp}
in
fun polish ctxt (mt : Monad_Types.monad_type) thm =
let
  (* Apply any polishing rules. *)
  val ctxt = Utils.set_hidden_ctxt ctxt
  (* Previously this also built an "asm_full_simp_tac" solver that was never
   * installed into the simpset below. That dead binding has been removed;
   * the rewriting behaviour is unchanged. *)
  val simps = PolishSimps.get ctxt

  (* Simplify using polish rules. *)
  val simp_conv =
    Simplifier.rewrite (put_simpset (#polish_rules mt) ctxt addsimps simps)

  (* eta-contract "prod_case" clauses, so that they render as:
   * "%(a, b). P a b" instead of "case x of (a, b) => P a b". *)
  val prod_case_conv =
    Conv.bottom_conv (
      K (Conv.try_conv (Conv.rewr_conv prod_case_eta_contract_thm))) ctxt

  val thm_p =
    Conv.fconv_rule (Conv.arg_conv (Utils.rhs_conv (
      simp_conv then_conv prod_case_conv))) thm
in
  thm_p
end
end

(*
 * monad_convert tactic.
 *
 * Rewrites subgoal "n" as follows:
 *   1. Find the first subterm matching "pattern_str".
 *   2. Determine which registered monad type that subterm belongs to.
 *   3. Unlift the subterm back to L2_monad.
 *   4. Lift it into the monad type named "monad_name".
 *   5. Polish the result.
 *   6. Substitute the resulting equation into the subgoal.
 *
 * Failures at each stage raise ERROR or TERM with a descriptive message.
 *)
fun monad_convert_tac (ctxt : Proof.context) (monad_name : string)
                      (pattern_str : string) (n : int) : tactic =
fn state => let
  val all_rules = Monad_Types.TSRules.get (Context.Proof ctxt)

  (* Figure out which monad to lift into. *)
  val target_rule = theE (Symtab.lookup all_rules monad_name)
    (ERROR ("monad_convert: could not find monad type " ^ quote monad_name))

  (* Search the subgoal for the supplied pattern. *)
  val pattern = parse_pattern ctxt pattern_str
  val subgoal = Logic.get_goal (term_of_thm state) n
  val (m_vars, m_term) = theE (grep_term ctxt pattern subgoal)
    (TERM ("monad_convert: failed to match pattern", [pattern]))

  (* Find a lifting rule whose output matches m_term.
   * This saves us from having to try every unlift rule. *)
  val orig_lift_rule = oneE (filter (fn mt => #valid_term mt ctxt m_term)
                                    (all_rules |> Symtab.dest |> map snd))
    (TERM ("monad_convert: could not determine monad type", [m_term]))

  (* Unlift back to L2_monad. *)
  val unlift_thm = theE (monad_rewrite ctxt orig_lift_rule [] false m_term)
    (TERM ("monad_convert: could not unlift term (rule: " ^
           #name orig_lift_rule ^ ")", [m_term]))
  val unlift_term = Utils.rhs_of (term_of_thm unlift_thm)

  (* Lift to target monad. *)
  val relift_thm = theE (monad_rewrite ctxt target_rule [] true unlift_term)
    (TERM ("monad_convert: could not lift to " ^ #name target_rule,
           [m_term, unlift_term]))

  (* Polish result. *)
  val relift_thm' = polish ctxt target_rule relift_thm
  val translate_thm = Thm.transitive unlift_thm relift_thm'

  (* Make variables schematic: re-prove the equation with the Free names
   * introduced by the search fixed, so that Goal.prove generalises them. *)
  val translate_thm' = Goal.prove ctxt (sort_distinct string_ord m_vars) []
                         (term_of_thm translate_thm) (K (rtac translate_thm 1))

  val result = EqSubst.eqsubst_tac ctxt [0] [translate_thm'] n state
in
  (* Force the first result so that a failed substitution is reported as an
   * error instead of silently producing an empty result sequence. *)
  case Seq.pull result of
    NONE => raise TERM ("monad_convert: failed to apply conversion",
                        [term_of_thm translate_thm', subgoal])
  | SOME (x, xs) => Seq.cons x xs
end

(* Register the "monad_convert" proof method. Its arguments are an optional
 * goal spec, the target monad name and the pattern string. *)
val _ = Context.>> (Context.map_theory
  (Method.setup (Binding.name "monad_convert")
    (* Based on subgoal_tac parser *)
    (Args.goal_spec -- Scan.lift (Parse.name -- Args.name_source) >>
      (fn (quant, (monad_name, term_str)) => fn ctxt =>
        SIMPLE_METHOD'' quant (monad_convert_tac ctxt monad_name term_str)))
    "autocorres monad conversion"))

end