From 1f0e5b3cb80e9334b07bf4b074c01150f4abd49d Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Mon, 22 May 2023 15:23:48 +0200
Subject: Make the unfolding theorems collection from evalLib persistent

---
 backends/hol4/divDefLib.sml            |   2 +
 backends/hol4/divDefLibTestScript.sml  |  16 ++
 backends/hol4/divDefLibTestTheory.sig  |   9 ++
 backends/hol4/evalLib.sig              |   5 +-
 backends/hol4/evalLib.sml              |  73 ++++++---
 backends/hol4/primitivesBaseTacLib.sml |  22 ++-
 backends/hol4/primitivesLib.sml        |  15 ++
 backends/hol4/primitivesScript.sml     |  47 ++----
 backends/hol4/saveThmsLib.sml          | 275 +++++++++++----------------------
 9 files changed, 222 insertions(+), 242 deletions(-)

(limited to 'backends/hol4')

diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml
index 1e95b6c5..c9d8806d 100644
--- a/backends/hol4/divDefLib.sml
+++ b/backends/hol4/divDefLib.sml
@@ -904,6 +904,8 @@ fun DefineDiv (def_qt : term quotation) =
        (we still do it: it doesn't cost much). *)
     val _ = app delete_binding thm_names
     val _ = map store_definition (zip thm_names def_eqs)
+    (* Also save the custom unfoldings, for evaluation (unit tests) *)
+    val _ = evalLib.add_unfold_thms thm_names
   in
     def_eqs
   end
diff --git a/backends/hol4/divDefLibTestScript.sml b/backends/hol4/divDefLibTestScript.sml
index 2e6d56b6..b01ec053 100644
--- a/backends/hol4/divDefLibTestScript.sml
+++ b/backends/hol4/divDefLibTestScript.sml
@@ -13,6 +13,21 @@ Datatype:
   | ListNil
 End
 
+(* A version of [nth] which doesn't use machine integers *)
+val [nth0_def] = DefineDiv ‘
+  nth0 (ls : 't list_t) (i : int) : 't result =
+    case ls of
+    | ListCons x tl =>
+      if i = (0:int)
+      then (Return x)
+      else
+        do
+        nth0 tl (i - 1)
+        od
+    | ListNil => Fail Failure
+’
+val _ = primitivesLib.assert_return “nth0 (ListCons 0 ListNil) 0”
+
 val [nth_def] = DefineDiv ‘
   nth (ls : 't list_t) (i : u32) : 't result =
     case ls of
@@ -26,6 +41,7 @@ val [nth_def] = DefineDiv ‘
         od
     | ListNil => Fail Failure
 ’
+val _ = primitivesLib.assert_return “nth (ListCons 0 ListNil) (int_to_u32 0)”
 
 (* even, odd *)
 
diff --git a/backends/hol4/divDefLibTestTheory.sig b/backends/hol4/divDefLibTestTheory.sig
index 526f74a6..d8cc4ab5 100644
--- a/backends/hol4/divDefLibTestTheory.sig
+++ b/backends/hol4/divDefLibTestTheory.sig
@@ -13,6 +13,7 @@ sig
     val list_t_size_def : thm
     val node_TY_DEF : thm
     val node_case_def : thm
+    val nth0_def : thm
     val nth_def : thm
     val odd_def : thm
     val tree_TY_DEF : thm
@@ -180,6 +181,14 @@ sig
       
       ⊢ ∀a f. node_CASE (Node a) f = f a
    
+   [nth0_def]  Definition
+      
+      ⊢ ∀ls i.
+          nth0 ls i =
+          case ls of
+            ListCons x tl => if i = 0 then Return x else nth0 tl (i − 1)
+          | ListNil => Fail Failure
+   
    [nth_def]  Definition
       
       ⊢ ∀ls i.
diff --git a/backends/hol4/evalLib.sig b/backends/hol4/evalLib.sig
index d045d3bb..f44cff13 100644
--- a/backends/hol4/evalLib.sig
+++ b/backends/hol4/evalLib.sig
@@ -8,10 +8,7 @@ sig
 
   include Abbrev
 
-  (* The following functions allow to register custom unfolding theorems *)
-  (* TODO: permanence of theorems? *)
-  val add_unfold : term * thm -> unit
-  val add_unfolds : (term * thm) list -> unit
+  (* The following functions allow to *persistently* register custom unfolding theorems *)
   val add_unfold_thm : thm -> unit
   val add_unfold_thms : thm list -> unit
 
diff --git a/backends/hol4/evalLib.sml b/backends/hol4/evalLib.sml
index 30e28eed..85a6b94a 100644
--- a/backends/hol4/evalLib.sml
+++ b/backends/hol4/evalLib.sml
@@ -27,37 +27,61 @@ fun get_const_name (tm : term) : const_name =
     (thy, name)
   end
 
-(* TODO: should we rather use srw_ss ()?
-   TODO: permanence of saved theorems? (see BasicProvers.export_rewrites) *)
-val custom_rewrites : ssfrag ref = ref empty_ssfrag
-val custom_unfolds : ((const_name, thm) Redblackmap.dict) ref = ref (Redblackmap.mkDict const_name_compare)
-val custom_unfolds_consts : (term Redblackset.set) ref = ref (Redblackset.empty Term.compare)
+(* Create the persistent collection, which is a pair:
+   (set of constants, map from constant name to theorem)
+ *)
+type state = term Redblackset.set * (const_name, thm) Redblackmap.dict
+fun get_custom_unfold_const (th : thm) : term = (fst o strip_comb o lhs o snd o strip_forall o concl) th
 
-fun add_rewrites (thms : thm list) : unit =
+(* Small helper *)
+fun add_to_state (th : thm) (s, m) =
   let
-    val rewrs = frag_rewrites (!custom_rewrites)
-    val _ = custom_rewrites := rewrites (List.@ (thms,  rewrs))
+    val k = get_custom_unfold_const th
+    val name = get_const_name k
+    val s = Redblackset.add (s, k)
+    val m = Redblackmap.insert (m, name, th)
   in
-    ()
+    (s, m)
   end
 
-fun add_unfold (tm : term, th : thm) : unit =
+(* Persistently update the maps given a delta  *)
+fun apply_delta (delta : ThmSetData.setdelta) st =
+  case delta of
+    ThmSetData.ADD (_, th) => add_to_state th st
+  | ThmSetData.REMOVE _ =>
+    raise mk_HOL_ERR "saveThmsLib" "create_map" ("Unexpected REMOVE")
+
+(* Initialize the collection *)
+val init_state = (Redblackset.empty compare, Redblackmap.mkDict const_name_compare)
+val {update_global_value, (* Local update *)
+     record_delta, (* Global update *)
+     get_global_value,
+     ...} =
+    ThmSetData.export_with_ancestry {
+      settype = "custom_unfold_theorems",
+      delta_ops = {apply_to_global = apply_delta,
+                   uptodate_delta = K true,
+                   thy_finaliser = NONE,
+                   initial_value = init_state,
+                   apply_delta = apply_delta}
+    }
+
+fun add_unfold_thm (s : string) : unit =
   let
-    val _ = custom_unfolds := Redblackmap.insert (!custom_unfolds, get_const_name tm, th)
-    val _ = custom_unfolds_consts := Redblackset.add (!custom_unfolds_consts, tm)
+    val th = saveThmsLib.lookup_thm s
   in
-    ()
+    (* Record the delta - for persistence for the future sessions *)
+      record_delta (ThmSetData.ADD th);
+    (* Update the global value - for the current session: record_delta
+       doesn't update the state of the current session *)
+    update_global_value (add_to_state (snd th))
   end
 
-fun add_unfolds (ls : (term * thm) list) : unit =
-  app add_unfold ls
-
-fun get_custom_unfold_const (th : thm) : term = (fst o strip_comb o lhs o snd o strip_forall o concl) th
-fun add_unfold_thm (th : thm) : unit = add_unfold (get_custom_unfold_const th, th)
-fun add_unfold_thms (ls : thm list) : unit = app add_unfold_thm ls
+fun add_unfold_thms (ls : string list) : unit =
+  app add_unfold_thm ls
 
 fun get_unfold_thms () : thm list =
-  map snd (Redblackmap.listItems (!custom_unfolds))
+  map snd (Redblackmap.listItems (snd (get_global_value ())))
 
 (* Apply a custom unfolding to the term, if possible.
 
@@ -65,6 +89,9 @@ fun get_unfold_thms () : thm list =
  *)
 fun apply_custom_unfold (tm : term) : thm =
   let
+    (* Retrieve the custom unfoldings *)
+    val custom_unfolds = snd (get_global_value ())
+
     (* Remove all the matches to find the top-most scrutinee *)
     val scrut = strip_all_cases_get_scrutinee tm
     (* Find an unfolding to apply: we look at the application itself, then
@@ -80,7 +107,7 @@ fun apply_custom_unfold (tm : term) : thm =
            theorem *)
         val c = (fst o strip_comb) tm
         val cname = get_const_name c
-        val unfold_th = Redblackmap.find (!custom_unfolds, cname)
+        val unfold_th = Redblackmap.find (custom_unfolds, cname)
           handle Redblackmap.NotFound => failwith "No theorem found"
         (* Instantiate the theorem *)
         val unfold_th = SPEC_ALL unfold_th
@@ -121,7 +148,7 @@ fun apply_custom_unfold (tm : term) : thm =
 fun eval_conv tm =
   let
     (* TODO: optimize: we recompute the list each time... *)
-    val restr_tms = Redblackset.listItems (!custom_unfolds_consts)
+    val restr_tms = Redblackset.listItems (fst (get_global_value ()))
     (* We do the following:
        - use the standard EVAL conv, but restrains it from unfolding terms for
          which we have custom unfolding theorems
@@ -129,7 +156,7 @@ fun eval_conv tm =
        - simplify
      *)
     val standard_eval = RESTR_EVAL_CONV restr_tms
-    val simp_no_fail_conv = (fn x => SIMP_CONV (srw_ss () ++ !custom_rewrites) [] x handle UNCHANGED => REFL x)
+    val simp_no_fail_conv = (fn x => SIMP_CONV (srw_ss ()) [] x handle UNCHANGED => REFL x)
     val one_step_conv = standard_eval THENC (apply_custom_unfold THENC simp_no_fail_conv)
     (* Wrap the conversion such that it fails if the term is unchanged *)
     fun one_step_changed_conv tm = (CHANGED_CONV one_step_conv) tm
diff --git a/backends/hol4/primitivesBaseTacLib.sml b/backends/hol4/primitivesBaseTacLib.sml
index 3d6d9e3e..143e25ce 100644
--- a/backends/hol4/primitivesBaseTacLib.sml
+++ b/backends/hol4/primitivesBaseTacLib.sml
@@ -648,9 +648,27 @@ fun strip_all_cases_get_scrutinee (t : term) : term =
        For instance: (fst o strip_case) “if i = 0 then ... else ...”
        returns “i” while we want to get “i = 0”.
 
-       We use [dest_case] for this reason.
+       Also, [dest_case] sometimes fails.
+       
+       Ex.:
+       {[
+         val t = “result_CASE (if T then Return 0 else Fail Failure) (λy. Return ()) Fail Diverge”
+         dest_case t
+       ]}
+       TODO: file an issue
+
+       We use a custom function [get_case_scrutinee] instead of [dest_case] for this reason.
      *)
-    (strip_all_cases_get_scrutinee o (fn (_, x, _) => x) o TypeBase.dest_case) t
+    let
+      fun get_case_scrutinee t =
+        let
+          val (_, tms) = strip_comb t
+        in
+          hd tms
+        end
+    in
+    (strip_all_cases_get_scrutinee o get_case_scrutinee) t
+    end
   else t
 
 (*
diff --git a/backends/hol4/primitivesLib.sml b/backends/hol4/primitivesLib.sml
index eed50e25..776843bf 100644
--- a/backends/hol4/primitivesLib.sml
+++ b/backends/hol4/primitivesLib.sml
@@ -511,4 +511,19 @@ val progress : tactic =
     map_first_tac progress_with thl (asms, g)
   end
 
+(* Small utility: check that a term evaluates to “Return” (used by the unit tests) *)
+fun assert_return (tm0 : term) : unit =
+  let
+    (* Evaluate the term *)
+    val tm = evalLib.eval tm0
+    (* Deconstruct it *)
+    val (app, _) = strip_comb tm
+    val {Thy, Name, ...} = dest_thy_const app
+    handle HOL_ERR _ => raise (mk_HOL_ERR "primitivesLib" "assert_return" ("The term doesn't evaluate to “Return ...”: " ^ term_to_string tm ^ "\n, final result: " ^ term_to_string tm))
+  in
+    if Thy = "primitives" andalso Name = "Return" then ()
+    else
+      raise (mk_HOL_ERR "primitivesLib" "assert_return" ("The term doesn't evaluate to “Return ...”: " ^ term_to_string tm ^ "\n, final result: " ^ term_to_string tm))
+  end
+
 end
diff --git a/backends/hol4/primitivesScript.sml b/backends/hol4/primitivesScript.sml
index 6f54fbfc..e10ce7e5 100644
--- a/backends/hol4/primitivesScript.sml
+++ b/backends/hol4/primitivesScript.sml
@@ -368,7 +368,21 @@ val all_int_to_scalar_to_int_unfold_lemmas = [
   u64_to_int_int_to_u64_unfold,
   u128_to_int_int_to_u128_unfold
 ]
-val _ = evalLib.add_unfold_thms (all_int_to_scalar_to_int_unfold_lemmas)
+
+val _ = evalLib.add_unfold_thms [
+  "isize_to_int_int_to_isize_unfold",
+  "i8_to_int_int_to_i8_unfold",
+  "i16_to_int_int_to_i16_unfold",
+  "i32_to_int_int_to_i32_unfold",
+  "i64_to_int_int_to_i64_unfold",
+  "i128_to_int_int_to_i128_unfold",
+  "usize_to_int_int_to_usize_unfold",
+  "u8_to_int_int_to_u8_unfold",
+  "u16_to_int_int_to_u16_unfold",
+  "u32_to_int_int_to_u32_unfold",
+  "u64_to_int_int_to_u64_unfold",
+  "u128_to_int_int_to_u128_unfold"
+]
 
 val int_to_i8_i8_to_int       = new_axiom ("int_to_i8_i8_to_int",       “∀i. int_to_i8 (i8_to_int i) = i”)
 val int_to_i16_i16_to_int     = new_axiom ("int_to_i16_i16_to_int",     “∀i. int_to_i16 (i16_to_int i) = i”)
@@ -1614,7 +1628,7 @@ Theorem mk_vec_unfold:
 Proof
   metis_tac [mk_vec_axiom]
 QED
-val _ = evalLib.add_unfold_thm mk_vec_unfold
+val _ = evalLib.add_unfold_thm "mk_vec_unfold"
 
 (* Defining ‘vec_insert_back’ *)
 val vec_insert_back_def = Define ‘
@@ -1645,33 +1659,4 @@ Proof
   sg_dep_rewrite_all_tac index_update_same >- cooper_tac >> fs []
 QED
 
-(* TODO: add theorems to the rewriting theorems
-from listSimps.sml:
-
-val LIST_ss = BasicProvers.thy_ssfrag "list"
-val _ = BasicProvers.logged_addfrags {thyname="list"} [LIST_EQ_ss]
-
-val list_rws = computeLib.add_thms
-  [
-   ALL_DISTINCT, APPEND, APPEND_NIL, CONS_11, DROP_compute, EL_restricted,
-   EL_simp_restricted, EVERY_DEF, EXISTS_DEF, FILTER, FIND_def, FLAT, FOLDL,
-   FOLDR, FRONT_DEF, GENLIST_AUX_compute, GENLIST_NUMERALS, HD, INDEX_FIND_def,
-   INDEX_OF_def, LAST_compute, LENGTH, LEN_DEF, LIST_APPLY_def, LIST_BIND_def,
-   LIST_IGNORE_BIND_def, LIST_LIFT2_def, LIST_TO_SET_THM, LLEX_def, LRC_def,
-   LUPDATE_compute, MAP, MAP2, NOT_CONS_NIL, NOT_NIL_CONS, NULL_DEF, oEL_def,
-   oHD_def,
-   PAD_LEFT, PAD_RIGHT, REVERSE_REV, REV_DEF, SHORTLEX_def, SNOC, SUM_ACC_DEF,
-   SUM_SUM_ACC,
-   TAKE_compute, TL, UNZIP, ZIP, computeLib.lazyfy_thm list_case_compute,
-   dropWhile_def, isPREFIX, list_size_def, nub_def, splitAtPki_def
-  ]
-
-fun list_compset () =
-   let
-      val base = reduceLib.num_compset()
-   in
-      list_rws base; base
-   end
-*)
-
 val _ = export_theory ()
diff --git a/backends/hol4/saveThmsLib.sml b/backends/hol4/saveThmsLib.sml
index d2523eeb..76b428cf 100644
--- a/backends/hol4/saveThmsLib.sml
+++ b/backends/hol4/saveThmsLib.sml
@@ -1,196 +1,107 @@
-structure saveThmsLib =
+structure saveThmsLib :> saveThmsLib =
 struct
 
-type simpset = simpLib.simpset;
+open HolKernel Abbrev
 
-open HolKernel boolLib markerLib;
+type thname = KernelSig.kernelname
 
-val op++ = simpLib.++;
-val op-* = simpLib.-*;
-
-val ERR = mk_HOL_ERR "saveThmsLib";
-
-fun add_simpls tyinfo ss =
-    (ss ++ simpLib.tyi_to_ssdata tyinfo) handle HOL_ERR _ => ss
-
-(* TODO: what is this? *)
-fun tyinfol() = TypeBasePure.listItems (TypeBase.theTypeBase())
-
-datatype stc_update = ADD_SSFRAG of simpLib.ssfrag | REMOVE_RWT of string
-type stc_state = simpset * bool * stc_update list
-  (* theorems, initialised-flag, update list (most recent first) *)
-
-val initial_simpset = simpLib.empty_ss
-fun ssf1 nth = simpLib.empty_ssfrag |> simpLib.add_named_rwt nth
-
-val state0 : stc_state = (initial_simpset, false, [])
-fun apply_delta d ((sset,initp,upds):stc_state) : stc_state =
-    case d of
-        ThmSetData.ADD nth =>
-        (sset ++ ssf1 nth, true, [])
-      | ThmSetData.REMOVE s => (sset -* [s], true, [])
-
-fun apply_stc_update (ADD_SSFRAG ssf, ss) = ss ++ ssf
-  | apply_stc_update (REMOVE_RWT n, ss) = ss -* [n]
-
-(* A stateful theorems collection *)
-datatype stc = STC_CON of {
-    name        : string,
-    thy_ssfrag  : string -> simpLib.ssfrag,
-    thy_simpset : string -> simpset option,
-    get_ss      : unit -> simpset,
-    export_thms : string list -> unit
+(* The user-provided functions *)
+type 'key key_fns = {
+  compare : 'key * 'key -> order,
+  get_key_from_thm : thm -> 'key
 }
 
-(* Create a stateful theorems collection *)
-fun create_stateful_theorem_set (stc_name : string) =
-  let
-(* val stc_name = "testStc" *)
-
-    fun init_state (st as (sset,initp,upds)) =
-        if initp then st
-        else
-          let fun init() =
-                  (List.foldl apply_stc_update sset (List.rev upds)
-                              |> rev_itlist add_simpls (tyinfol()),
-                   true, [])
-          in
-            HOL_PROGRESS_MESG ("Initialising STC simpset: " ^ stc_name ^ " ... ", "done") init ()
-          end
-
-    fun opt_partition f g ls =
-        let
-          fun recurse As Bs ls =
-              case ls of
-                  [] => (List.rev As, List.rev Bs)
-                | h::t => (case f h of
-                               SOME a => recurse (a::As) Bs t
-                             | NONE => (case g h of
-                                            SOME b => recurse As (b::Bs) t
-                                         | NONE => recurse As Bs t))
-        in
-          recurse [] [] ls
-        end
-
-    (* stale-ness is important for derived values. Derived values will get
-       re-calculated if their flag is true when the value is requested.
-    *)
-    val stale_flags = Sref.new ([] : bool Sref.t list)
-    fun notify () =
-        List.app (fn br => Sref.update br (K true)) (Sref.value stale_flags)
-
-    fun apply_to_global d (st as (sset,initp,upds):stc_state) : stc_state =
-        if not initp then
-          case d of
-              ThmSetData.ADD nth =>
-              let
-                open simpLib
-                val upds' =
-                    case upds of
-                        ADD_SSFRAG ssf :: rest =>
-                        ADD_SSFRAG (add_named_rwt nth ssf) :: rest
-                      | _ => ADD_SSFRAG (ssf1 nth) :: upds
-              in
-                (sset, initp, upds')
-              end
-            | ThmSetData.REMOVE s => (sset, initp, REMOVE_RWT s :: upds)
-        else
-          apply_delta d st before notify()
-
-    fun finaliser {thyname} deltas (sset,initp,upds) =
-      let
-        fun toNamedAdd (ThmSetData.ADD p) = SOME p | toNamedAdd _ = NONE
-        fun toRM (ThmSetData.REMOVE s) = SOME s | toRM _ = NONE
-        val (adds,rms) = opt_partition toNamedAdd toRM deltas
-        val ssfrag = simpLib.named_rewrites_with_names thyname (List.rev adds)
-          (* List.rev here preserves old behaviour wrt to the way theorems were
-             added to the global simpset; it will only make a difference when
-             overall rewrite system is not confluent *)
-        val new_upds = ADD_SSFRAG ssfrag :: map REMOVE_RWT rms
-      in
-        if initp then
-          (List.foldl apply_stc_update sset new_upds, true, []) before notify()
-        else (sset, false, List.revAppend(new_upds, upds))
-      end
-
-
-    val adresult as {DB,get_global_value,record_delta,update_global_value,...} =
-      ThmSetData.export_with_ancestry {
-        delta_ops = {
-          apply_delta = apply_delta,
-          apply_to_global = apply_to_global,
-          thy_finaliser = SOME finaliser,
-          initial_value = state0, uptodate_delta = K true
-        },
-        settype = stc_name
-      }
-
-    val get_deltas = #get_deltas adresult
-
-    (*
-    (* TODO: what is this? *)
-    fun update_fn tyi =
-      augment_stc_ss ([simpLib.tyi_to_ssdata tyi] handle HOL_ERR _ => [])
-
-    val () = TypeBase.register_update_fn (fn tyi => (update_fn tyi; tyi))
-    *)
-
-    fun get_ss () =
-        (update_global_value init_state;
-         #1 (get_global_value()))
-
-    fun export_thms slist =
-      let val ds = map ThmSetData.mk_add slist
-      in
-        List.app record_delta ds;
-        update_global_value (rev_itlist apply_to_global ds)
-      end
+(* The functions we return to the user to manipulate the map *)
+type 'key map_fns = {
+  (* Persistently save a theorem *)
+  save_thm : string -> unit,
+  (* Temporarily save a theorem *)
+  temp_save_thm : thm -> unit,
+  (* Get the key set *)
+  get_keys : unit -> 'key Redblackset.set,
+  (* Get the theorems map *)
+  get_map : unit -> ('key, thm) Redblackmap.dict
+}
 
-    (* assume that there aren't any removes for things added in this theory;
-       it's not rational to do that; one should add it locally only, or not
-       add it at all
-    *)
-    fun mkfrag_from thy setdeltas =
-      let fun recurse ADDs [] = ADDs
-            | recurse ADDs (ThmSetData.ADD p :: rest) = recurse (p::ADDs) rest
-            | recurse ADDs (_ :: rest) = recurse ADDs rest
-          val ADDs = recurse [] setdeltas
-            (* order of addition is flipped; see above for why this is
-               "reasonable" *)
-      in
-        simpLib.named_rewrites_with_names thy ADDs
-      end
-    fun thy_ssfrag s = get_deltas {thyname=s} |> mkfrag_from s
+(* This function is adapted from ThmSetData.sml.
 
-    fun thy_simpset s = Option.map (#1 o init_state) (DB {thyname=s})
+   It raises an exception if it can't find a theorem.
+ *)
+fun lookup_thm (s : string) : thname * thm =
+  let
+    val name =
+     case String.fields (equal #".") s of
+         [s0] => {Thy = current_theory(), Name = s}
+       | [s1,s2] => {Thy = s1, Name = s2}
+       | _ => raise mk_HOL_ERR "saveThmsLib" "lookup_thm" ("Malformed name: " ^ s)
+    fun lookup_exn {Thy,Name} = DB.fetch Thy Name
   in
-    STC_CON { name = stc_name, thy_ssfrag = thy_ssfrag, thy_simpset = thy_simpset,
-              get_ss = get_ss, export_thms = export_thms }
+    (name, lookup_exn name)
   end
 
-fun rewrite_thms_of_simpset (ss : simpset) : thm list =
-  List.concat (map simpLib.frag_rewrites (simpLib.ssfrags_of ss))
-
-(*
-(* Create a stateful theorems collection *)
-val STC_CON { name = stc_name, thy_ssfrag = thy_ssfrag, thy_simpset = thy_simpset, get_ss = get_ss, export_thms = export_thms } =
-  create_stateful_theorem_set "testStc1"
+(* The state, a pair (set of keys, map from keys to theorems) *)
+type 'key state = 'key Redblackset.set * ('key, thm) Redblackmap.dict
 
-Theorem th1:
-  T /\ (T /\ T)
-Proof
-  fs []
-QED
-
-export_thms ["th1"]
-
-fun rewrite_thms_of_simpset (ss : simpset) : thm list =
-  List.concat (map simpLib.frag_rewrites (simpLib.ssfrags_of ss))
-
-rewrite_thms_of_simpset (get_ss ())
-*)
-
-val STC_CON { name = stc_name, thy_ssfrag = thy_ssfrag, thy_simpset = thy_simpset, get_ss = get_ss, export_thms = export_thms } =
-  create_stateful_theorem_set "testStc1"
+(* Initialize a persistent map *)
+fun create_map (kf : 'key key_fns) (name : string) : 'key map_fns =
+  let
+     val { compare, get_key_from_thm } = kf
+     
+     (* Small helper *)
+     fun add_to_state (th : thm) (s, m) =
+       let
+         val k = get_key_from_thm th
+         val s = Redblackset.add (s, k)
+         val m = Redblackmap.insert (m, k, th)
+       in
+         (s, m)
+       end
+  
+     (* Persistently update the map given a delta  *)
+     fun apply_delta (delta : ThmSetData.setdelta) st =
+       case delta of
+         ThmSetData.ADD (_, th) => add_to_state th st
+       | ThmSetData.REMOVE _ =>
+         raise mk_HOL_ERR "saveThmsLib" "create_map" ("Unexpected REMOVE")
+
+     (* Initialize the dictionary *)
+     val init_state = (Redblackset.empty compare, Redblackmap.mkDict compare)
+     val {update_global_value, (* Local update *)
+          record_delta, (* Global update *)
+          get_deltas,
+          get_global_value,
+          DB = eval_ruleuction_map_by_theory,...} =
+         ThmSetData.export_with_ancestry {
+           settype = name,
+           delta_ops = {apply_to_global = apply_delta,
+                        uptodate_delta = K true,
+                        thy_finaliser = NONE,
+                        initial_value = init_state,
+                        apply_delta = apply_delta}
+         }
+
+     (* Temporarily save a theorem: update the current session, but don't
+        save the delta for the future sessions. *)
+     fun temp_save_thm (th : thm) : unit =
+       update_global_value (add_to_state th)
+
+     (* Persistently save a theorem *)
+     fun save_thm (s : string) : unit =
+       let
+         val th = lookup_thm s
+       in
+         (* Record delta saves a delta for the future sessions, but doesn't
+            update the current sessions, which is why we also call [temp_save_thm] *)
+           record_delta (ThmSetData.ADD th);
+           temp_save_thm (snd th)
+       end
+
+     (* *)
+     val get_keys = fst o get_global_value
+     val get_map = snd o get_global_value
+  in
+    { save_thm = save_thm, temp_save_thm = temp_save_thm,
+      get_keys = get_keys, get_map = get_map }
+  end
 
 end
-- 
cgit v1.2.3