diff options
author | Son Ho | 2023-05-18 15:59:08 +0200 |
---|---|---|
committer | Son HO | 2023-06-04 21:54:38 +0200 |
commit | 8500a6ad00b21287850faa12d81a59fa46ca0848 (patch) | |
tree | df5cd2c31480566869e0558c512377718df88174 | |
parent | 7a39a2c8f7608ca7dd337fd7fc6ff9a56be33de0 (diff) |
Commit evalLib
-rw-r--r-- | backends/hol4/evalLib.sig | 26 | ||||
-rw-r--r-- | backends/hol4/evalLib.sml | 144 |
2 files changed, 170 insertions, 0 deletions
diff --git a/backends/hol4/evalLib.sig b/backends/hol4/evalLib.sig new file mode 100644 index 00000000..d045d3bb --- /dev/null +++ b/backends/hol4/evalLib.sig @@ -0,0 +1,26 @@ +signature evalLib = +sig + (* This module implements [eval_conv], which supersedes EVAL_CONV by + allowing to use custom unfolding theorems. This is particularly + useful for divDefLib, which returns rewriting theorems to the user + which are actually not definitional theorems. + *) + + include Abbrev + + (* The following functions allow to register custom unfolding theorems *) + (* TODO: permanence of theorems? *) + val add_unfold : term * thm -> unit + val add_unfolds : (term * thm) list -> unit + val add_unfold_thm : thm -> unit + val add_unfold_thms : thm list -> unit + + (* Get the unfolding theorems *) + val get_unfold_thms : unit -> thm list + + (* The custom "eval" conversion *) + val eval_conv : conv + + (* The custom "eval" function *) + val eval : term -> term +end diff --git a/backends/hol4/evalLib.sml b/backends/hol4/evalLib.sml new file mode 100644 index 00000000..30e28eed --- /dev/null +++ b/backends/hol4/evalLib.sml @@ -0,0 +1,144 @@ +structure evalLib = +struct + +open simpLib computeLib +open primitivesBaseTacLib + +(* Lexicographic order over pairs - TODO: this duplicates primitivesBaseTacLib *) +fun pair_compare (comp1 : 'a * 'a -> order) (comp2 : 'b * 'b -> order) + ((p1, p2) : (('a * 'b) * ('a * 'b))) : order = + let + val (x1, y1) = p1; + val (x2, y2) = p2; + in + case comp1 (x1, x2) of + LESS => LESS + | GREATER => GREATER + | EQUAL => comp2 (y1, y2) + end + +(* A constant name (theory, constant name) - TODO: this duplicates primitivesBaseTacLib *) +type const_name = string * string +val const_name_compare = pair_compare String.compare String.compare +fun get_const_name (tm : term) : const_name = + let + val {Thy=thy, Name=name, Ty=_} = dest_thy_const tm + in + (thy, name) + end + +(* TODO: should we rather use srw_ss ()? + TODO: permanence of saved theorems? (see BasicProvers.export_rewrites) *) +val custom_rewrites : ssfrag ref = ref empty_ssfrag +val custom_unfolds : ((const_name, thm) Redblackmap.dict) ref = ref (Redblackmap.mkDict const_name_compare) +val custom_unfolds_consts : (term Redblackset.set) ref = ref (Redblackset.empty Term.compare) + +fun add_rewrites (thms : thm list) : unit = + let + val rewrs = frag_rewrites (!custom_rewrites) + val _ = custom_rewrites := rewrites (List.@ (thms, rewrs)) + in + () + end + +fun add_unfold (tm : term, th : thm) : unit = + let + val _ = custom_unfolds := Redblackmap.insert (!custom_unfolds, get_const_name tm, th) + val _ = custom_unfolds_consts := Redblackset.add (!custom_unfolds_consts, tm) + in + () + end + +fun add_unfolds (ls : (term * thm) list) : unit = + app add_unfold ls + +fun get_custom_unfold_const (th : thm) : term = (fst o strip_comb o lhs o snd o strip_forall o concl) th +fun add_unfold_thm (th : thm) : unit = add_unfold (get_custom_unfold_const th, th) +fun add_unfold_thms (ls : thm list) : unit = app add_unfold_thm ls + +fun get_unfold_thms () : thm list = + map snd (Redblackmap.listItems (!custom_unfolds)) + +(* Apply a custom unfolding to the term, if possible. + + This conversion never fails nor raises exceptions. + *) +fun apply_custom_unfold (tm : term) : thm = + let + (* Remove all the matches to find the top-most scrutinee *) + val scrut = strip_all_cases_get_scrutinee tm + (* Find an unfolding to apply: we look at the application itself, then + all its arguments, then unfold it. + TODO: maybe do something more systematic, like CBV + + This function returns a theorem or raises an HOL_ERR + *) + fun find_unfolding (tm : term) : thm = + (* Try to find an unfolding on the current term *) + let + (* Deconstruct the constant and attempt to lookup an unfolding + theorem *) + val c = (fst o strip_comb) tm + val cname = get_const_name c + val unfold_th = Redblackmap.find (!custom_unfolds, cname) + handle Redblackmap.NotFound => failwith "No theorem found" + (* Instantiate the theorem *) + val unfold_th = SPEC_ALL unfold_th + val th_tm = (lhs o concl) unfold_th + val (subst, inst) = match_term th_tm tm + val unfold_th = INST subst (INST_TYPE inst unfold_th) + in + unfold_th + end + handle HOL_ERR _ => + (* Can't unfold the current term: decompose it and explore the subterms *) + if is_comb tm then + let + val tm = strip_all_cases_get_scrutinee tm + val (app, args) = strip_comb tm + val tml = List.rev (app :: args) + in + find_unfolding_in_list tml + end + else failwith "Found no constant on which to apply an unfolding" + and find_unfolding_in_list (ls : term list) : thm = + case ls of + [] => failwith "Found no constant on which to apply an unfolding" + | tm :: tl => + (* Explore the argument, if it fails explore the application *) + find_unfolding tm + handle HOL_ERR _ => find_unfolding_in_list tl + val unfold_th = find_unfolding scrut + in + (* Apply the theorem - for security, we apply it only once. + TODO: we might want to be more precise and apply it exactly where + we need to. + *) + PURE_ONCE_REWRITE_CONV [unfold_th] tm + end + handle HOL_ERR _ => REFL tm + +fun eval_conv tm = + let + (* TODO: optimize: we recompute the list each time... *) + val restr_tms = Redblackset.listItems (!custom_unfolds_consts) + (* We do the following: + - use the standard EVAL conv, but restrains it from unfolding terms for + which we have custom unfolding theorems + - apply custom unfoldings + - simplify + *) + val standard_eval = RESTR_EVAL_CONV restr_tms + val simp_no_fail_conv = (fn x => SIMP_CONV (srw_ss () ++ !custom_rewrites) [] x handle UNCHANGED => REFL x) + val one_step_conv = standard_eval THENC (apply_custom_unfold THENC simp_no_fail_conv) + (* Wrap the conversion such that it fails if the term is unchanged *) + fun one_step_changed_conv tm = (CHANGED_CONV one_step_conv) tm + in + (* Repeat *) + REPEATC one_step_changed_conv tm + end + handle UNCHANGED => REFL tm + +fun eval tm = (rhs o concl) (eval_conv tm) + +end |