(* Source: backends/hol4/evalLib.sml
   (blob f9e758c9d59ff1f1bf34ac086d4e528ae176db58) *)
structure evalLib =
struct

open simpLib computeLib
open primitivesBaseTacLib

(* Lexicographic comparison of two pairs.

   The first components are compared with [comp1]; [comp2] is consulted on
   the second components only when the first ones are EQUAL.

   TODO: this duplicates primitivesBaseTacLib *)
fun pair_compare (comp1 : 'a * 'a -> order) (comp2 : 'b * 'b -> order)
  ((p1, p2) : (('a * 'b) * ('a * 'b))) : order =
  case comp1 (#1 p1, #1 p2) of
    EQUAL => comp2 (#2 p1, #2 p2)
  | decided => decided

(* A constant is identified by the pair (theory name, constant name).
   TODO: this duplicates primitivesBaseTacLib *)
type const_name = string * string

(* Order on constant names: theory name first, then constant name *)
val const_name_compare = pair_compare String.compare String.compare

(* Build the (theory, name) identifier of a constant from the record returned
   by [dest_thy_const]; the type component of that record is ignored. *)
fun get_const_name (tm : term) : const_name =
  let
    val info = dest_thy_const tm
  in
    (#Thy info, #Name info)
  end

(* Persistent collection of rewrite theorems, stored as a list.

   Yields two functions:
   - [add_rewrite_thm name]: looks the theorem up by name with
     [saveThmsLib.lookup_thm] and adds it to the collection
   - [get_rewrite_thms ()]: returns the current list of theorems

   TODO: more efficient collection than a list? *)
val (add_rewrite_thm, get_rewrite_thms) =
  let
    (* Push a theorem in front of the current list *)
    fun cons_thm (th : thm) thms = th :: thms

    (* Effect of a recorded delta on the state: additions are pushed in
       front of the list, removals are rejected *)
    fun apply_delta (ThmSetData.ADD (_, th)) thms = cons_thm th thms
      | apply_delta (ThmSetData.REMOVE _) _ =
          raise mk_HOL_ERR "evalLib" "apply_delta" ("Unexpected REMOVE")

    (* Register the collection; its initial value is the empty list *)
    val {update_global_value, (* Local update *)
         record_delta, (* Global update *)
         get_global_value = get_rewrite_thms,
         ...} =
        ThmSetData.export_with_ancestry {
          settype = "evalLib_rewrite_theorems",
          delta_ops = {apply_to_global = apply_delta,
                       uptodate_delta = K true,
                       thy_finaliser = NONE,
                       initial_value = [],
                       apply_delta = apply_delta}
        }

    fun add_rewrite_thm (s : string) : unit =
      let
        val named_th = saveThmsLib.lookup_thm s
        val (_, th) = named_th
      in
        (* First record the delta, so that the addition persists in future
           sessions; then update the value seen by the current session,
           which recording the delta alone does not do *)
        (record_delta (ThmSetData.ADD named_th);
         update_global_value (cons_thm th))
      end
  in
    (add_rewrite_thm, get_rewrite_thms)
  end

(* Register every theorem of the list (given by name), from left to right *)
fun add_rewrite_thms (ls : string list) : unit =
  case ls of
    [] => ()
  | name :: rest => (add_rewrite_thm name; add_rewrite_thms rest)


(* State of the persistent collection of unfolding theorems, which is a pair:
   (set of constants, map from constant name to theorem)
 *)
type unfold_state = term Redblackset.set * (const_name, thm) Redblackmap.dict

(* Head of the left-hand side of an unfolding theorem: strip the universal
   quantifiers of the conclusion, take the left-hand side of the resulting
   equation, and return the term found in function position. *)
fun get_custom_unfold_const (th : thm) : term =
  let
    val (_, equation) = strip_forall (concl th)
    val (head, _) = strip_comb (lhs equation)
  in
    head
  end

(* Persistent collection of custom unfolding theorems.

   Yields two functions:
   - [add_unfold_thm name]: looks the theorem up by name with
     [saveThmsLib.lookup_thm] and registers it for the constant found at the
     head of its left-hand side (see [get_custom_unfold_const]). Raises an
     exception, without recording anything, if no such constant can be
     extracted from the theorem.
   - [get_unfold_thms_map ()]: returns the current state, i.e. the set of
     constants which have an unfolding theorem, together with the map from
     constant names to those theorems.
 *)
val (add_unfold_thm, get_unfold_thms_map) =
  let
    (* Small helper: register [th] for the constant it unfolds. A later
       theorem for the same constant replaces the earlier one in the map. *)
    fun add_to_state (th : thm) (s, m) =
      let
        val k = get_custom_unfold_const th
        val name = get_const_name k
        val s = Redblackset.add (s, k)
        val m = Redblackmap.insert (m, name, th)
      in
        (s, m)
      end

    (* Persistently update the maps given a delta *)
    fun apply_delta (delta : ThmSetData.setdelta) st =
      case delta of
        ThmSetData.ADD (_, th) => add_to_state th st
      | ThmSetData.REMOVE _ =>
        raise mk_HOL_ERR "evalLib" "apply_delta" ("Unexpected REMOVE")

    (* Initialize the collection *)
    val init_state = (Redblackset.empty compare, Redblackmap.mkDict const_name_compare)
    val {update_global_value, (* Local update *)
         record_delta, (* Global update *)
         get_global_value = get_unfold_thms_map,
         ...} =
        ThmSetData.export_with_ancestry {
          settype = "evalLib_unfold_theorems",
          delta_ops = {apply_to_global = apply_delta,
                       uptodate_delta = K true,
                       thy_finaliser = NONE,
                       initial_value = init_state,
                       apply_delta = apply_delta}
        }

    fun add_unfold_thm (s : string) : unit =
      let
        val th = saveThmsLib.lookup_thm s
        (* Check that the theorem has the expected shape *before* recording
           anything: [add_to_state] performs exactly these two steps, and if
           they fail only once the delta has been recorded, the malformed
           theorem is persisted even though the current session rejected
           it. *)
        val _ = get_const_name (get_custom_unfold_const (snd th))
      in
        (* Record the delta - for persistence for the future sessions *)
        record_delta (ThmSetData.ADD th);
        (* Update the global value - for the current session: record_delta
           doesn't update the state of the current session *)
        update_global_value (add_to_state (snd th))
      end
  in
    (add_unfold_thm, get_unfold_thms_map)
  end

(* Register every unfolding theorem of the list (given by name), from left to
   right *)
fun add_unfold_thms (ls : string list) : unit =
  case ls of
    [] => ()
  | name :: rest => (add_unfold_thm name; add_unfold_thms rest)

(* All the registered unfolding theorems, in the order in which the map from
   constant names to theorems lists its entries *)
fun get_unfold_thms () : thm list =
  let
    val (_, thms_by_const) = get_unfold_thms_map ()
    val entries = Redblackmap.listItems thms_by_const
  in
    map (fn (_, th) => th) entries
  end

(* Apply a custom unfolding to the term, if possible.

   Returns a theorem whose left-hand side is [tm]. When no unfolding theorem
   applies, or when rewriting leaves the term unchanged, the result is
   [REFL tm].

   This conversion never fails: both HOL_ERR and the UNCHANGED exception
   (which the rewriting conversion may use to signal that it did nothing) are
   turned into [REFL tm].
 *)
fun apply_custom_unfold (tm : term) : thm =
  let
    (* Retrieve the custom unfoldings *)
    val custom_unfolds = snd (get_unfold_thms_map ())

    (* Remove all the matches to find the top-most scrutinee *)
    val scrut = strip_all_cases_get_scrutinee tm
    (* Find an unfolding to apply: we look at the application itself, then
       all its arguments (last argument first), and finally the function
       being applied.
       TODO: maybe do something more systematic, like CBV

       This function returns a theorem or raises an HOL_ERR
     *)
    fun find_unfolding (tm : term) : thm =
      (* Try to find an unfolding on the current term *)
      let
        (* Deconstruct the constant and attempt to lookup an unfolding
           theorem *)
        val c = (fst o strip_comb) tm
        val cname = get_const_name c
        val unfold_th = Redblackmap.find (custom_unfolds, cname)
          handle Redblackmap.NotFound => failwith "No theorem found"
        (* Instantiate the theorem so that its left-hand side is exactly
           the current term *)
        val unfold_th = SPEC_ALL unfold_th
        val th_tm = (lhs o concl) unfold_th
        val (subst, inst) = match_term th_tm tm
        val unfold_th = INST subst (INST_TYPE inst unfold_th)
      in
        unfold_th
      end
      handle HOL_ERR _ =>
        (* Can't unfold the current term: decompose it and explore the subterms *)
        if is_comb tm then
          let
            val tm = strip_all_cases_get_scrutinee tm
            val (app, args) = strip_comb tm
            (* The list is reversed: the last argument is explored first and
               the applied function last *)
            val tml = List.rev (app :: args)
          in
            find_unfolding_in_list tml
          end
        else failwith "Found no constant on which to apply an unfolding"
    and find_unfolding_in_list (ls : term list) : thm =
      case ls of
        [] => failwith "Found no constant on which to apply an unfolding"
      | tm :: tl =>
        (* Explore the first term; if that fails, move on to the next ones *)
          find_unfolding tm
          handle HOL_ERR _ => find_unfolding_in_list tl
    val unfold_th = find_unfolding scrut
  in
    (* Apply the theorem - for security, we apply it only once.
       TODO: we might want to be more precise and apply it exactly where
       we need to.
     *)
    PURE_ONCE_REWRITE_CONV [unfold_th] tm
  end
  handle HOL_ERR _ => REFL tm
         (* The rewriting conversion may signal "nothing rewritten" with
            UNCHANGED rather than with an HOL_ERR: honour the no-exception
            contract in this case as well. The exception is qualified so
            that the pattern can only ever denote that exception. *)
       | Conv.UNCHANGED => REFL tm

(* Evaluation conversion. One step is made of:
   - the standard EVAL conversion, restrained from unfolding the constants
     for which we have custom unfolding theorems
   - the application of a custom unfolding
   - a simplification with the registered rewrite theorems

   The step is repeated for as long as it changes the term. If the term is
   left unchanged, the result is [REFL tm].
 *)
fun eval_conv tm =
  let
    (* TODO: optimize: we recompute the list each time... *)
    val (unfold_consts, _) = get_unfold_thms_map ()
    val protected_tms = Redblackset.listItems unfold_consts

    (* Standard evaluation which leaves the protected constants folded *)
    val restricted_eval = RESTR_EVAL_CONV protected_tms

    (* Simplification which returns a reflexivity theorem rather than
       raising UNCHANGED *)
    fun simp_or_refl t =
      SIMP_CONV (srw_ss ()) (get_rewrite_thms ()) t
      handle UNCHANGED => REFL t

    val step = restricted_eval THENC (apply_custom_unfold THENC simp_or_refl)

    (* Same as [step], but fails if the term is unchanged: this is what
       stops the repetition below *)
    fun progressing_step t = CHANGED_CONV step t
  in
    REPEATC progressing_step tm
  end
  handle UNCHANGED => REFL tm

(* Evaluate a term: the right-hand side of the theorem built by [eval_conv] *)
fun eval tm =
  let
    val eval_th = eval_conv tm
  in
    rhs (concl eval_th)
  end

end