summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-05-18 15:59:08 +0200
committerSon HO2023-06-04 21:54:38 +0200
commit8500a6ad00b21287850faa12d81a59fa46ca0848 (patch)
treedf5cd2c31480566869e0558c512377718df88174
parent7a39a2c8f7608ca7dd337fd7fc6ff9a56be33de0 (diff)
Commit evalLib
Diffstat (limited to '')
-rw-r--r--backends/hol4/evalLib.sig26
-rw-r--r--backends/hol4/evalLib.sml144
2 files changed, 170 insertions, 0 deletions
diff --git a/backends/hol4/evalLib.sig b/backends/hol4/evalLib.sig
new file mode 100644
index 00000000..d045d3bb
--- /dev/null
+++ b/backends/hol4/evalLib.sig
@@ -0,0 +1,26 @@
+signature evalLib =
+sig
+ (* This module implements [eval_conv], which supersedes EVAL_CONV by
+ allowing to use custom unfolding theorems. This is particularly
+ useful for divDefLib, which returns rewriting theorems to the user
+ which are actually not definitional theorems.
+ *)
+
+ include Abbrev
+
+ (* The following functions allow to register custom unfolding theorems *)
+ (* TODO: permanence of theorems? *)
+ val add_unfold : term * thm -> unit
+ val add_unfolds : (term * thm) list -> unit
+ val add_unfold_thm : thm -> unit
+ val add_unfold_thms : thm list -> unit
+
+ (* Get the unfolding theorems *)
+ val get_unfold_thms : unit -> thm list
+
+ (* The custom "eval" conversion *)
+ val eval_conv : conv
+
+ (* The custom "eval" function *)
+ val eval : term -> term
+end
diff --git a/backends/hol4/evalLib.sml b/backends/hol4/evalLib.sml
new file mode 100644
index 00000000..30e28eed
--- /dev/null
+++ b/backends/hol4/evalLib.sml
@@ -0,0 +1,144 @@
+structure evalLib =
+struct
+
+open simpLib computeLib
+open primitivesBaseTacLib
+
+(* Lexicographic order over pairs - TODO: this duplicates primitivesBaseTacLib *)
+fun pair_compare (comp1 : 'a * 'a -> order) (comp2 : 'b * 'b -> order)
+ ((p1, p2) : (('a * 'b) * ('a * 'b))) : order =
+ let
+ val (x1, y1) = p1;
+ val (x2, y2) = p2;
+ in
+ case comp1 (x1, x2) of
+ LESS => LESS
+ | GREATER => GREATER
+ | EQUAL => comp2 (y1, y2)
+ end
+
+(* A constant name (theory, constant name) - TODO: this duplicates primitivesBaseTacLib *)
+type const_name = string * string
+val const_name_compare = pair_compare String.compare String.compare
+fun get_const_name (tm : term) : const_name =
+ let
+ val {Thy=thy, Name=name, Ty=_} = dest_thy_const tm
+ in
+ (thy, name)
+ end
+
+(* TODO: should we rather use srw_ss ()?
+ TODO: permanence of saved theorems? (see BasicProvers.export_rewrites) *)
+val custom_rewrites : ssfrag ref = ref empty_ssfrag
+val custom_unfolds : ((const_name, thm) Redblackmap.dict) ref = ref (Redblackmap.mkDict const_name_compare)
+val custom_unfolds_consts : (term Redblackset.set) ref = ref (Redblackset.empty Term.compare)
+
+fun add_rewrites (thms : thm list) : unit =
+ let
+ val rewrs = frag_rewrites (!custom_rewrites)
+ val _ = custom_rewrites := rewrites (List.@ (thms, rewrs))
+ in
+ ()
+ end
+
+fun add_unfold (tm : term, th : thm) : unit =
+ let
+ val _ = custom_unfolds := Redblackmap.insert (!custom_unfolds, get_const_name tm, th)
+ val _ = custom_unfolds_consts := Redblackset.add (!custom_unfolds_consts, tm)
+ in
+ ()
+ end
+
+fun add_unfolds (ls : (term * thm) list) : unit =
+ app add_unfold ls
+
+fun get_custom_unfold_const (th : thm) : term = (fst o strip_comb o lhs o snd o strip_forall o concl) th
+fun add_unfold_thm (th : thm) : unit = add_unfold (get_custom_unfold_const th, th)
+fun add_unfold_thms (ls : thm list) : unit = app add_unfold_thm ls
+
+fun get_unfold_thms () : thm list =
+ map snd (Redblackmap.listItems (!custom_unfolds))
+
+(* Apply a custom unfolding to the term, if possible.
+
+ This conversion never fails nor raises exceptions.
+ *)
+fun apply_custom_unfold (tm : term) : thm =
+ let
+ (* Remove all the matches to find the top-most scrutinee *)
+ val scrut = strip_all_cases_get_scrutinee tm
+ (* Find an unfolding to apply: we look at the application itself, then
+ all its arguments, then unfold it.
+ TODO: maybe do something more systematic, like CBV
+
+ This function returns a theorem or raises an HOL_ERR
+ *)
+ fun find_unfolding (tm : term) : thm =
+ (* Try to find an unfolding on the current term *)
+ let
+ (* Deconstruct the constant and attempt to lookup an unfolding
+ theorem *)
+ val c = (fst o strip_comb) tm
+ val cname = get_const_name c
+ val unfold_th = Redblackmap.find (!custom_unfolds, cname)
+ handle Redblackmap.NotFound => failwith "No theorem found"
+ (* Instantiate the theorem *)
+ val unfold_th = SPEC_ALL unfold_th
+ val th_tm = (lhs o concl) unfold_th
+ val (subst, inst) = match_term th_tm tm
+ val unfold_th = INST subst (INST_TYPE inst unfold_th)
+ in
+ unfold_th
+ end
+ handle HOL_ERR _ =>
+ (* Can't unfold the current term: decompose it and explore the subterms *)
+ if is_comb tm then
+ let
+ val tm = strip_all_cases_get_scrutinee tm
+ val (app, args) = strip_comb tm
+ val tml = List.rev (app :: args)
+ in
+ find_unfolding_in_list tml
+ end
+ else failwith "Found no constant on which to apply an unfolding"
+ and find_unfolding_in_list (ls : term list) : thm =
+ case ls of
+ [] => failwith "Found no constant on which to apply an unfolding"
+ | tm :: tl =>
+ (* Explore the argument, if it fails explore the application *)
+ find_unfolding tm
+ handle HOL_ERR _ => find_unfolding_in_list tl
+ val unfold_th = find_unfolding scrut
+ in
+ (* Apply the theorem - for security, we apply it only once.
+ TODO: we might want to be more precise and apply it exactly where
+ we need to.
+ *)
+ PURE_ONCE_REWRITE_CONV [unfold_th] tm
+ end
+ handle HOL_ERR _ => REFL tm
+
+fun eval_conv tm =
+ let
+ (* TODO: optimize: we recompute the list each time... *)
+ val restr_tms = Redblackset.listItems (!custom_unfolds_consts)
+ (* We do the following:
+ - use the standard EVAL conv, but restrains it from unfolding terms for
+ which we have custom unfolding theorems
+ - apply custom unfoldings
+ - simplify
+ *)
+ val standard_eval = RESTR_EVAL_CONV restr_tms
+ val simp_no_fail_conv = (fn x => SIMP_CONV (srw_ss () ++ !custom_rewrites) [] x handle UNCHANGED => REFL x)
+ val one_step_conv = standard_eval THENC (apply_custom_unfold THENC simp_no_fail_conv)
+ (* Wrap the conversion such that it fails if the term is unchanged *)
+ fun one_step_changed_conv tm = (CHANGED_CONV one_step_conv) tm
+ in
+ (* Repeat *)
+ REPEATC one_step_changed_conv tm
+ end
+ handle UNCHANGED => REFL tm
+
+fun eval tm = (rhs o concl) (eval_conv tm)
+
+end