summaryrefslogtreecommitdiff
path: root/backends/hol4
diff options
context:
space:
mode:
authorSon Ho2023-05-12 18:56:39 +0200
committerSon HO2023-06-04 21:54:38 +0200
commitc49fd4b6230a1f926e929f133794b6f73d338077 (patch)
tree85de1862bbfb7f79bbe0b83722ae76ef1a52dead /backends/hol4
parent2e5415364d3ad61a524980621d25713aa31f6e79 (diff)
Reimplement DefineDiv with the new fixed point operator
Diffstat (limited to '')
-rw-r--r--backends/hol4/divDefProto2Script.sml1
-rw-r--r--backends/hol4/divDefProto2TestScript.sml1245
-rw-r--r--backends/hol4/primitivesBaseTacLib.sml5
3 files changed, 1249 insertions, 2 deletions
diff --git a/backends/hol4/divDefProto2Script.sml b/backends/hol4/divDefProto2Script.sml
index 9dc43ea7..9efe835b 100644
--- a/backends/hol4/divDefProto2Script.sml
+++ b/backends/hol4/divDefProto2Script.sml
@@ -8,7 +8,6 @@ open primitivesLib
val _ = new_theory "divDefProto2"
-
(*
* Test with a general validity predicate.
*
diff --git a/backends/hol4/divDefProto2TestScript.sml b/backends/hol4/divDefProto2TestScript.sml
new file mode 100644
index 00000000..39719b65
--- /dev/null
+++ b/backends/hol4/divDefProto2TestScript.sml
@@ -0,0 +1,1245 @@
+open HolKernel boolLib bossLib Parse
+open boolTheory arithmeticTheory integerTheory intLib listTheory stringTheory
+
+open primitivesArithTheory primitivesBaseTacLib ilistTheory primitivesTheory
+open primitivesLib
+open divDefProto2Theory
+
+val _ = new_theory "divDefProto2TestScript"
+
+(*======================
+ * Example 1: nth
+ *======================*)
+Datatype:
+ list_t =
+ ListCons 't list_t
+ | ListNil
+End
+
+(* We use this version of the body to prove that the body is valid *)
+val nth_body_def = Define ‘
+ nth_body (f : (('t list_t # u32) + 't) -> (('t list_t # u32) + 't) result)
+ (x : (('t list_t # u32) + 't)) :
+ (('t list_t # u32) + 't) result =
+ (* Destruct the input. We need this to call the proper function in case
+ of mutually recursive definitions, but also to eliminate arguments
+ which correspond to the output value (the input type is the same
+ as the output type). *)
+ case x of
+ | INL x => (
+ let (ls, i) = x in
+ case ls of
+ | ListCons x tl =>
+ if u32_to_int i = (0:int)
+ then Return (INR x)
+ else
+ do
+ i0 <- u32_sub i (int_to_u32 1);
+ r <- f (INL (tl, i0));
+ (* Eliminate the invalid outputs. This is not necessary here,
+ but it is in the case of non tail call recursive calls. *)
+ case r of
+ | INL _ => Fail Failure
+ | INR i1 => Return (INR i1)
+ od
+ | ListNil => Fail Failure)
+ | INR _ => Fail Failure
+’
+
+val dbg = ref false
+fun print_dbg s = if (!dbg) then print s else ()
+
+(* Tactic which makes progress in a proof of validity by making a case
+ disjunction (we use this to explore all the paths in a function body). *)
+fun prove_valid_case_progress
+
+(*
+val (asms, g) = top_goal ()
+*)
+
+(* TODO: move *)
+(* This theorem is important to shape the goal when proving that a body
+ satifies the fixed point validity property.
+
+ Importantly: this theorem relies on the fact that errors are just transmitted
+ to the caller (in particular, without modification).
+ *)
+Theorem case_result_switch_eq:
+ (case (case x of Return y => f y | Fail e => Fail e | Diverge => Diverge) of
+ | Return y => g y
+ | Fail e => Fail e
+ | Diverge => Diverge) =
+ (case x of
+ | Return y =>
+ (case f y of
+ | Return y => g y
+ | Fail e => Fail e
+ | Diverge => Diverge)
+ | Fail e => Fail e
+ | Diverge => Diverge)
+Proof
+ Cases_on ‘x’ >> fs []
+QED
+
+(* Tactic to prove that a body is valid: perform one step. *)
+fun prove_body_is_valid_tac_step (asms, g) =
+ let
+ (* The goal has the shape:
+ {[
+ (∀g h. ... g x = ... h x) ∨
+ ∃h y. is_valid_fp_body n h ∧ ∀g. ... g x = ... od
+ ]}
+ *)
+ (* Retrieve the scrutinee in the goal (‘x’).
+ There are two cases:
+ - either the function has the shape:
+ {[
+ (λ(y,z). ...) x
+ ]}
+ in which case we need to destruct ‘x’
+ - or we have a normal ‘case ... of’
+ *)
+ val body = (lhs o snd o strip_forall o fst o dest_disj) g
+ val scrut =
+ let
+ val (app, x) = dest_comb body
+ val (app, _) = dest_comb app
+ val {Name=name, Thy=thy, Ty = _ } = dest_thy_const app
+ in
+ if thy = "pair" andalso name = "UNCURRY" then x else failwith "not a curried argument"
+ end
+ handle HOL_ERR _ => strip_all_cases_get_scrutinee body
+ (* Retrieve the first quantified continuations from the goal (‘g’) *)
+ val qc = (hd o fst o strip_forall o fst o dest_disj) g
+ (* Check if the scrutinee is a recursive call *)
+ val (scrut_app, _) = strip_comb scrut
+ val _ = print_dbg ("prove_body_is_valid_step: Scrutinee: " ^ term_to_string scrut ^ "\n")
+ (* For the recursive calls: *)
+ fun step_rec () =
+ let
+ val _ = print_dbg ("prove_body_is_valid_step: rec call\n")
+ (* We need to instantiate the ‘h’ existantially quantified function *)
+ (* First, retrieve the body of the function: it is given by the ‘Return’ branch *)
+ val (_, _, branches) = TypeBase.dest_case body
+ (* Find the branch corresponding to the return *)
+ val ret_branch = List.find (fn (pat, _) =>
+ let
+ val {Name=name, Thy=thy, Ty = _ } = (dest_thy_const o fst o strip_comb) pat
+ in
+ thy = "primitives" andalso name = "Return"
+ end) branches
+ val var = (hd o snd o strip_comb o fst o valOf) ret_branch
+ val br = (snd o valOf) ret_branch
+ (* Abstract away the input variable introduced by the pattern and the continuation ‘g’ *)
+ val h = list_mk_abs ([qc, var], br)
+ val _ = print_dbg ("prove_body_is_valid_step: h: " ^ term_to_string h ^ "\n")
+ (* Retrieve the input parameter ‘x’ *)
+ val input = (snd o dest_comb) scrut
+ val _ = print_dbg ("prove_body_is_valid_step: y: " ^ term_to_string input ^ "\n")
+ in
+ ((* Choose the right possibility (this is a recursive call) *)
+ disj2_tac >>
+ (* Instantiate the quantifiers *)
+ qexists ‘^h’ >>
+ qexists ‘^input’ >>
+ (* Unfold the predicate once *)
+ pure_once_rewrite_tac [is_valid_fp_body_def] >>
+ (* We have two subgoals:
+ - we have to prove that ‘h’ is valid
+ - we have to finish the proof of validity for the current body
+ *)
+ conj_tac >> fs [case_result_switch_eq])
+ end
+ in
+ (* If recursive call: special treatment. Otherwise, we do a simple disjunction *)
+ (if term_eq scrut_app qc then step_rec ()
+ else (Cases_on ‘^scrut’ >> fs [case_result_switch_eq])) (asms, g)
+ end
+
+(* Tactic to prove that a body is valid *)
+fun prove_body_is_valid_tac (body_def : thm option) : tactic =
+ let val body_def_thm = case body_def of SOME th => [th] | NONE => []
+ in
+ pure_once_rewrite_tac [is_valid_fp_body_def] >> gen_tac >>
+ (* Expand *)
+ fs body_def_thm >>
+ fs [bind_def, case_result_switch_eq] >>
+ (* Explore the body *)
+ rpt prove_body_is_valid_tac_step
+ end
+
+(* TODO: move *)
+val is_valid_fp_body_tm = “is_valid_fp_body”
+
+(* Prove that a body satisfies the validity condition of the fixed point *)
+fun prove_body_is_valid (body : term) : thm =
+ let
+ (* Explore the body and count the number of occurrences of nested recursive
+ calls so that we can properly instantiate the ‘N’ argument of ‘is_valid_fp_body’.
+
+ We first retrieve the name of the continuation parameter.
+ Rem.: we generated fresh names so that, for instance, the continuation name
+ doesn't collide with other names. Because of this, we don't need to look for
+ collisions when exploring the body (and in the worst case, we would cound
+ an overapproximation of the number of recursive calls, which is perfectly
+ valid).
+ *)
+ val fcont = (hd o fst o strip_abs) body
+ val fcont_name = (fst o dest_var) fcont
+ fun max x y = if x > y then x else y
+ fun count_body_rec_calls (body : term) : int =
+ case dest_term body of
+ VAR (name, _) => if name = fcont_name then 1 else 0
+ | CONST _ => 0
+ | COMB (x, y) => max (count_body_rec_calls x) (count_body_rec_calls y)
+ | LAMB (_, x) => count_body_rec_calls x
+ val num_rec_calls = count_body_rec_calls body
+
+ (* Generate the term ‘SUC (SUC ... (SUC n))’ where ‘n’ is a fresh variable.
+
+ Remark: we first prove ‘is_valid_fp_body (SUC ... n) body’ then substitue
+ ‘n’ with ‘0’ to prevent the quantity from being rewritten to a bit
+ representation, which would prevent unfolding of the ‘is_valid_fp_body’.
+ *)
+ val nvar = genvar num_ty
+ (* Rem.: we stack num_rec_calls + 1 occurrences of ‘SUC’ (and the + 1 is important) *)
+ fun mk_n i = if i = 0 then mk_suc nvar else mk_suc (mk_n (i-1))
+ val n_tm = mk_n num_rec_calls
+
+ (* Generate the lemma statement *)
+ val is_valid_tm = list_mk_icomb (is_valid_fp_body_tm, [n_tm, body])
+ val is_valid_thm = prove (is_valid_tm, prove_body_is_valid_tac NONE)
+
+ (* Replace ‘nvar’ with ‘0’ *)
+ val is_valid_thm = INST [nvar |-> zero_num_tm] is_valid_thm
+ in
+ is_valid_thm
+ end
+
+(*
+val (asms, g) = top_goal ()
+*)
+
+(* We first prove the theorem with ‘SUC (SUC n)’ where ‘n’ is a variable
+ to prevent this quantity from being rewritten to 2 *)
+Theorem nth_body_is_valid_aux:
+ is_valid_fp_body (SUC (SUC n)) nth_body
+Proof
+ prove_body_is_valid_tac (SOME nth_body_def)
+QED
+
+Theorem nth_body_is_valid:
+ is_valid_fp_body (SUC (SUC 0)) nth_body
+Proof
+ irule nth_body_is_valid_aux
+QED
+
+val nth_raw_def = Define ‘
+ nth (ls : 't list_t) (i : u32) =
+ case fix nth_body (INL (ls, i)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR x => Return x
+’
+
+val fix_tm = “fix”
+
+(* Generate the raw definitions by using the grouped definition body and the
+ fixed point operator *)
+fun mk_raw_defs (in_out_tys : (hol_type * hol_type) list)
+ (def_tms : term list) (body_is_valid : thm) : thm list =
+ let
+ (* Retrieve the body *)
+ val body = (List.last o snd o strip_comb o concl) body_is_valid
+
+ (* Create the term ‘fix body’ *)
+ val fixed_body = mk_icomb (fix_tm, body)
+
+ (* For every function in the group, generate the equation that we will
+ use as definition. In particular:
+ - add the properly injected input ‘x’ to ‘fix body’ (ex.: for ‘nth ls i’
+ we add the input ‘INL (ls, i)’)
+ - wrap ‘fix body x’ into a case disjunction to extract the relevant output
+
+ For instance, in the case of ‘nth ls i’:
+ {[
+ nth (ls : 't list_t) (i : u32) =
+ case fix nth_body (INL (ls, i)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR x => Return x
+ ]}
+ *)
+ fun mk_def_eq (i : int, def_tm : term) : term =
+ let
+ (* Retrieve the lhs of the original definition equation, and in
+ particular the inputs *)
+ val def_lhs = lhs def_tm
+ val args = (snd o strip_comb) def_lhs
+
+ (* Inject the inputs into the param type *)
+ val input = pairSyntax.list_mk_pair args
+ val input = inject_in_param_sum in_out_tys i true input
+
+ (* Compose*)
+ val def_rhs = mk_comb (fixed_body, input)
+
+ (* Wrap in the case disjunction *)
+ val def_rhs = mk_case_select_result_sum def_rhs in_out_tys i false
+
+ (* Create the equation *)
+ val def_eq_tm = mk_eq (def_lhs, def_rhs)
+ in
+ def_eq_tm
+ end
+ val raw_def_tms = map mk_def_eq (enumerate def_tms)
+
+ (* Generate the definitions *)
+ val raw_defs = map (fn tm => Define ‘^tm’) raw_def_tms
+ in
+ raw_defs
+ end
+
+(* Tactic which makes progress in a proof by making a case disjunction (we use
+ this to explore all the paths in a function body). *)
+fun case_progress (asms, g) =
+ let
+ val scrut = (strip_all_cases_get_scrutinee o lhs) g
+ in Cases_on ‘^scrut’ (asms, g) end
+
+(* Prove the final equation, that we will use as definition. *)
+fun prove_def_eq_tac
+ (current_raw_def : thm) (all_raw_defs : thm list) (is_valid : thm)
+ (body_def : thm option) : tactic =
+ let
+ val body_def_thm = case body_def of SOME th => [th] | NONE => []
+ in
+ rpt gen_tac >>
+ (* Expand the definition *)
+ pure_once_rewrite_tac [current_raw_def] >>
+ (* Use the fixed-point equality *)
+ pure_once_rewrite_left_tac [HO_MATCH_MP fix_fixed_eq is_valid] >>
+ (* Expand the body definition *)
+ pure_rewrite_tac body_def_thm >>
+ (* Expand all the definitions from the group *)
+ pure_rewrite_tac all_raw_defs >>
+ (* Explore all the paths - maybe we can be smarter, but this is fast and really easy *)
+ fs [bind_def] >>
+ rpt (case_progress >> fs [])
+ end
+
+(* Prove the final equations that we will give to the user as definitions *)
+fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm list) : thm list=
+ let
+ val defs_tgt_raw = zip def_tms raw_defs
+ (* Substitute the function variables with the constants introduced in the raw
+ definitions *)
+ fun compute_fsubst (def_tm, raw_def) : {redex: term, residue: term} =
+ let
+ val (fvar, _) = (strip_comb o lhs) def_tm
+ val fconst = (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def
+ in
+ (fvar |-> fconst)
+ end
+ val fsubst = map compute_fsubst defs_tgt_raw
+ val defs_tgt_raw = map (fn (x, y) => (subst fsubst x, y)) defs_tgt_raw
+
+ fun prove_def_eq (def_tm, raw_def) : thm =
+ let
+ (* Quantify the parameters *)
+ val (_, params) = (strip_comb o lhs) def_tm
+ val def_eq_tm = list_mk_forall (params, def_tm)
+ (* Prove *)
+ val def_eq = prove (def_eq_tm, prove_def_eq_tac raw_def raw_defs body_is_valid NONE)
+ in
+ def_eq
+ end
+ val def_eqs = map prove_def_eq defs_tgt_raw
+ in
+ def_eqs
+ end
+
+Theorem nth_def:
+ ∀ls i. nth (ls : 't list_t) (i : u32) : 't result =
+ case ls of
+ | ListCons x tl =>
+ if u32_to_int i = (0:int)
+ then (Return x)
+ else
+ do
+ i0 <- u32_sub i (int_to_u32 1);
+ nth tl i0
+ od
+ | ListNil => Fail Failure
+Proof
+ prove_def_eq_tac nth_raw_def [nth_raw_def] nth_body_is_valid nth_body_def
+QED
+
+(*======================
+ * Example 2: even, odd
+ *======================*)
+
+val def_qt = ‘
+ (even (i : int) : bool result =
+ if i = 0 then Return T else odd (i - 1)) /\
+ (odd (i : int) : bool result =
+ if i = 0 then Return F else even (i - 1))
+’
+
+val result_ty = “:'a result”
+val error_ty = “:error”
+val alpha_ty = “:'a”
+val num_ty = “:num”
+
+val return_tm = “Return : 'a -> 'a result”
+val fail_tm = “Fail : error -> 'a result”
+val fail_failure_tm = “Fail Failure : 'a result”
+val diverge_tm = “Diverge : 'a result”
+
+val zero_num_tm = “0:num”
+val suc_tm = “SUC”
+
+fun mk_result (ty : hol_type) : hol_type = Type.type_subst [ alpha_ty |-> ty ] result_ty
+fun dest_result (ty : hol_type) : hol_type =
+ let
+ val {Args=out_ty, Thy=thy, Tyop=tyop} = dest_thy_type ty
+ in
+ if thy = "primitives" andalso tyop = "result" then hd out_ty
+ else failwith "dest_result: not a result"
+ end
+
+fun mk_return (x : term) : term = mk_icomb (return_tm, x)
+fun mk_fail (ty : hol_type) (e : term) : term = mk_comb (inst [ alpha_ty |-> ty ] fail_tm, e)
+fun mk_fail_failure (ty : hol_type) : term = inst [ alpha_ty |-> ty ] fail_failure_tm
+fun mk_diverge (ty : hol_type) : term = inst [ alpha_ty |-> ty ] diverge_tm
+
+fun mk_suc (n : term) = mk_comb (suc_tm, n)
+
+(*
+ *)
+
+(* **BODY GENERATION**:
+ ====================
+
+ When generating a recursive definition, we apply a fixed-point operator to
+ a function body. In case we define a group of mutually recursive definitions,
+ we generate *one* single body for the whole group of definitions. It works
+ as follows.
+
+ The input of the body is an enumeration: we start by branching over this input, and
+ every branch corresponds to one function in the mutually recursive group. Also, the
+ inputs must be grouped into tuples. Whenever we make a recursive call, we wrap the
+ input parameters into the proper variant, so as to call the proper function.
+
+ Moreover, the input of the body must have the same type as its output: we also
+ store the outputs of the functions in some variants of the enumeration.
+
+ In order to make this work, we need to shape the body so that:
+ - input values/output values are properly injected into the enumeration
+ - whenever we get an output value (which is an enumeration), we extract
+ the value from the proper variant of the enumeration
+
+ We encode the enumeration with a nested sum type, whose constructors
+ are ‘INL’ and ‘INR’.
+
+ Example:
+ ========
+ We consider the following group of mutually recursive definitions: *)
+
+val even_odd_qt = Defn.parse_quote ‘
+ (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\
+ (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1))
+’
+
+(* From those equations, we generate the following body: *)
+
+val even_odd_body_def = Define ‘
+ even_odd_body
+ (* The body takes a continuation - required by the fixed-point operator *)
+ (f : (int + bool + int + bool) -> (int + bool + int + bool) result)
+
+ (* The type of the input is:
+ input of even + output of even + input of odd + output of odd *)
+ (x : int + bool + int + bool) :
+
+ (* The output type is the same as the input type - this constraint
+ comes from limitations in the way we can define the fixed-point
+ operator inside the HOL logic *)
+ (int + bool + int + bool) result =
+
+ (* Case disjunction over the input, in order to figure out which
+ function from the group is actually called (even, or odd). *)
+ case x of
+ | INL i => (* Input of even *)
+ (* Body of even *)
+ if i = 0 then Return (INR (INL T))
+ else
+ (* Recursive calls are calls to the continuation f, wrapped
+ in the proper variant: here we call odd *)
+ (case f (INR (INR (INL (i - 1)))) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ (* Extract the proper value from the enumeration: here, the
+ call is tail-call so this is not really necessary, but we
+ might need to retrieve the output of the call to odd, which
+ is a boolean, and do something more complex with it. *)
+ case r of
+ | INL _ => Fail Failure
+ | INR (INL _) => Fail Failure
+ | INR (INR (INL _)) => Fail Failure
+ | INR (INR (INR b)) => (* Extract the output of odd *)
+ (* Return: inject into the variant for the output of even *)
+ Return (INR (INL b))
+ )
+ | INR (INL _) => (* Output of even *)
+ (* We must ignore this one *)
+ Fail Failure
+ | INR (INR (INL i)) =>
+ (* Body of odd *)
+ if i = 0 then Return (INR (INR (INR F)))
+ else
+ (* Call to even *)
+ (case f (INL (i - 1)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ (* Extract the proper value from the enumeration *)
+ case r of
+ | INL _ => Fail Failure
+ | INR (INL b) => (* Extract the output of even *)
+ (* Return: inject into the variant for the output of odd *)
+ Return (INR (INR (INR b)))
+ | INR (INR (INL _)) => Fail Failure
+ | INR (INR (INR _)) => Fail Failure
+ )
+ | INR (INR (INR _)) => (* Output of odd *)
+ (* We must ignore this one *)
+ Fail Failure
+’
+
+(* Small helper to generate wrappers of the shape: ‘INL x’, ‘INR (INL x)’, etc.
+ Note that we should have: ‘length before_tys + 1 + length after tys >= 2’
+
+ Ex.:
+ ====
+ The enumeration has type: “: 'a + 'b + 'c + 'd”.
+ We want to generate the variant which injects “x:'c” into this enumeration.
+
+ We need to split the list of types into:
+ {[
+ before_tys = [“:'a”, “'b”]
+ tm = “x: 'c”
+ after_tys = [“:'d”]
+ ]}
+
+ The function will generate:
+ {[
+ INR (INR (INL x) : 'a + 'b + 'c + 'd
+ ]}
+
+ (* Debug *)
+ val before_tys = [“:'a”, “:'b”, “:'c”]
+ val tm = “x:'d”
+ val after_tys = [“:'e”, “:'f”]
+
+ val before_tys = [“:'a”, “:'b”, “:'c”]
+ val tm = “x:'d”
+ val after_tys = []
+
+ mk_inl_inr_wrapper before_tys tm after_tys
+ *)
+fun list_mk_inl_inr (before_tys : hol_type list) (tm : term) (after_tys : hol_type list) :
+ term =
+ let
+ val (before_tys, pat) =
+ if after_tys = []
+ then
+ let
+ val just_before_ty = List.last before_tys
+ val before_tys = List.take (before_tys, List.length before_tys - 1)
+ val pat = sumSyntax.mk_inr (tm, just_before_ty)
+ in
+ (before_tys, pat)
+ end
+ else (before_tys, sumSyntax.mk_inl (tm, sumSyntax.list_mk_sum after_tys))
+ val pat = foldr (fn (ty, pat) => sumSyntax.mk_inr (pat, ty)) pat before_tys
+ in
+ pat
+ end
+
+
+(* This function wraps a term into the proper variant of the input/output
+ sum.
+
+ Ex.:
+ ====
+ For the input of the first function, we generate: ‘INL x’
+ For the output of the first function, we generate: ‘INR (INL x)’
+ For the input of the 2nd function, we generate: ‘INR (INR (INL x))’
+ etc.
+
+ If ‘is_input’ is true: we are wrapping an input. Otherwise we are wrapping
+ an output.
+
+ (* Debug *)
+ val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”), (“:'e”, “:'f”)]
+ val j = 1
+ val tm = “x:'c”
+ val tm = “y:'d”
+ val is_input = true
+ *)
+fun inject_in_param_sum (tys : (hol_type * hol_type) list) (j : int) (is_input : bool)
+ (tm : term) : term =
+ let
+ fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
+ val before_tys = flatten (List.take (tys, j))
+ val (input_ty, output_ty) = List.nth (tys, j)
+ val after_tys = flatten (List.drop (tys, j + 1))
+ val (before_tys, after_tys) =
+ if is_input then (before_tys, output_ty :: after_tys)
+ else (before_tys @ [input_ty], after_tys)
+ in
+ list_mk_inl_inr before_tys tm after_tys
+ end
+
+(* Remark: the order of the branches when creating matches is important.
+ For instance, in the case of ‘result’ it must be: ‘Return’, ‘Fail’, ‘Diverge’.
+
+ For the purpose of stability and maintainability, we introduce this small helper
+ which reorders the cases in a pattern before actually creating the case
+ expression.
+ *)
+fun unordered_mk_case (scrut: term, pats: (term * term) list) : term =
+ let
+ (* Retrieve the constructors *)
+ val cl = TypeBase.constructors_of (type_of scrut)
+ (* Retrieve the names of the constructors *)
+ val names = map (fst o dest_const) cl
+ (* Use those to reorder the patterns *)
+ fun is_pat (name : string) (pat, _) =
+ let
+ val app = (fst o strip_comb) pat
+ val app_name = (fst o dest_const) app
+ in
+ app_name = name
+ end
+ val pats = map (fn name => valOf (List.find (is_pat name) pats)) names
+ in
+ (* Create the case *)
+ TypeBase.mk_case (scrut, pats)
+ end
+
+(* Wrap a term of type “:'a result” into a ‘case of’ which matches over
+ the result.
+
+ Ex.:
+ ====
+ {[
+ f x
+
+ ~~>
+
+ case f x of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return y => ... (* The branch content is generated by the continuation *)
+ ]}
+
+ ‘gen_ret_branch’ is a *continuation* which generates the content of the
+ ‘Return’ branch (i.e., the content of the ‘...’ in the example above).
+ It receives as input the value contained by the ‘Return’ (i.e., the variable
+ ‘y’ in the example above).
+
+ Remark.: the type of the term generated by ‘gen_ret_branch’ must have
+ the type ‘result’, but it can change the content of the result (i.e.,
+ if ‘scrut’ has type ‘:'a result’, we can change the type of the wrapped
+ expression to ‘:'b result’).
+
+ (* Debug *)
+ val scrut = “x: int result”
+ fun gen_ret_branch x = mk_return x
+
+ val scrut = “x: int result”
+ fun gen_ret_branch _ = “Return T”
+
+ mk_result_case scrut gen_ret_branch
+ *)
+fun mk_result_case (scrut : term) (gen_ret_branch : term -> term) : term =
+ let
+ val scrut_ty = dest_result (type_of scrut)
+ (* Return branch *)
+ val ret_var = genvar scrut_ty
+ val ret_pat = mk_return ret_var
+ val ret_br = gen_ret_branch ret_var
+ val ret_ty = dest_result (type_of ret_br)
+ (* Failure branch *)
+ val fail_var = genvar error_ty
+ val fail_pat = mk_fail scrut_ty fail_var
+ val fail_br = mk_fail ret_ty fail_var
+ (* Diverge branch *)
+ val div_pat = mk_diverge scrut_ty
+ val div_br = mk_diverge ret_ty
+ in
+ unordered_mk_case (scrut, [(ret_pat, ret_br), (fail_pat, fail_br), (div_pat, div_br)])
+ end
+
+(* Generate a ‘case ... of’ over a sum type.
+
+ Ex.:
+ ====
+ If the scrutinee is: “x : 'a + 'b + 'c” (i.e., the tys list is: [“:'a”, “:b”, “:c”]),
+ we generate:
+
+ {[
+ case x of
+ | INL y0 => ... (* Branch of index 0 *)
+ | INR (INL y1) => ... (* Branch of index 1 *)
+ | INR (INR (INL y2)) => ... (* Branch of index 2 *)
+ | INR (INR (INR y3)) => ... (* Branch of index 3 *)
+ ]}
+
+ The content of the branches is generated by the ‘gen_branch’ continuation,
+ which receives as input the index of the branch as well as the variable
+ introduced by the pattern (in the example above: ‘y0’ for the branch 0,
+ ‘y1’ for the branch 1, etc.)
+
+ (* Debug *)
+ val tys = [“:'a”, “:'b”]
+ val scrut = mk_var ("x", sumSyntax.list_mk_sum tys)
+ fun gen_branch i (x : term) = “F”
+
+ val tys = [“:'a”, “:'b”, “:'c”, “:'d”]
+ val scrut = mk_var ("x", sumSyntax.list_mk_sum tys)
+ fun gen_branch i (x : term) = if type_of x = “:'c” then mk_return x else mk_fail_failure “:'c”
+
+ list_mk_sum_case scrut tys gen_branch
+ *)
+(* For debugging *)
+val list_mk_sum_case_case = ref (“T”, [] : (term * term) list)
+(*
+val (scrut, [(pat1, br1), (pat2, br2)]) = !list_mk_sum_case_case
+*)
+fun list_mk_sum_case (scrut : term) (tys : hol_type list)
+ (gen_branch : int -> term -> term) : term =
+ let
+ (* Create the cases. Note that without sugar, the match actually looks like this:
+ {[
+ case x of
+ | INL y0 => ... (* Branch of index 0 *)
+ | INR x1
+ case x1 of
+ | INL y1 => ... (* Branch of index 1 *)
+ | INR x2 =>
+ case x2 of
+ | INL y2 => ... (* Branch of index 2 *)
+ | INR y3 => ... (* Branch of index 3 *)
+ ]}
+ *)
+ fun create_case (j : int) (scrut : term) (tys : hol_type list) : term =
+ let
+ val _ = print_dbg ("list_mk_sum_case: " ^
+ String.concatWith ", " (map type_to_string tys) ^ "\n")
+ in
+ case tys of
+ [] => failwith "tys is too short"
+ | [ ty ] =>
+ (* Last element: no match to perform *)
+ gen_branch j scrut
+ | ty1 :: tys =>
+ (* Not last: we create a pattern:
+ {[
+ case scrut of
+ | INL pat_var1 => ... (* Branch of index i *)
+ | INR pat_var2 =>
+ ... (* Generate this term recursively *)
+ ]}
+ *)
+ let
+ (* INL branch *)
+ val after_ty = sumSyntax.list_mk_sum tys
+ val pat_var1 = genvar ty1
+ val pat1 = sumSyntax.mk_inl (pat_var1, after_ty)
+ val br1 = gen_branch j pat_var1
+ (* INR branch *)
+ val pat_var2 = genvar after_ty
+ val pat2 = sumSyntax.mk_inr (pat_var2, ty1)
+ val br2 = create_case (j+1) pat_var2 tys
+ val _ = print_dbg ("list_mk_sum_case: assembling:\n" ^
+ term_to_string scrut ^ ",\n" ^
+ "[(" ^ term_to_string pat1 ^ ",\n " ^ term_to_string br1 ^ "),\n\n" ^
+ " (" ^ term_to_string pat2 ^ ",\n " ^ term_to_string br2 ^ ")]\n\n")
+ val case_elems = (scrut, [(pat1, br1), (pat2, br2)])
+ val _ = list_mk_sum_case_case := case_elems
+ in
+ (* Put everything together *)
+ TypeBase.mk_case case_elems
+ end
+ end
+ in
+ create_case 0 scrut tys
+ end
+
+(* Generate a ‘case ... of’ to select the input/output of the ith variant of
+ the param enumeration.
+
+ Ex.:
+ ====
+ There are two functions in the group, and we select the input of the function of index 1:
+ {[
+ case x of
+ | INL _ => Fail Failure (* Input of function of index 0 *)
+ | INR (INL _) => Fail Failure (* Output of function of index 0 *)
+ | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *)
+ | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *)
+ ]}
+
+ (* Debug *)
+ val tys = [(“:'a”, “:'b”)]
+ val scrut = “x : 'a + 'b”
+ val fi = 0
+ val is_input = true
+
+ val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”)]
+ val scrut = “x : 'a + 'b + 'c + 'd”
+ val fi = 1
+ val is_input = false
+
+ val scrut = mk_var ("x", sumSyntax.list_mk_sum (flatten tys))
+
+ list_mk_case_select scrut tys fi is_input
+ *)
+fun list_mk_case_sum_select (scrut : term) (tys : (hol_type * hol_type) list)
+ (fi : int) (is_input : bool) : term =
+ let
+ (* The index of the element in the enumeration that we will select *)
+ val i = 2 * fi + (if is_input then 0 else 1)
+ (* Flatten the types and numerotate them *)
+ fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
+ val tys = flatten tys
+ (* Get the return type *)
+ val ret_ty = List.nth (tys, i)
+ (* The continuation which will generate the content of the branches *)
+ fun gen_branch j var = if j = i then mk_return var else mk_fail_failure ret_ty
+ in
+ (* Generate the ‘case ... of’ *)
+ list_mk_sum_case scrut tys gen_branch
+ end
+
+(* Generate a ‘case ... of’ to select the input/output of the ith variant of
+ the param enumeration.
+
+ Ex.:
+ ====
+ There are two functions in the group, and we select the input of the function of index 1:
+ {[
+ case x of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure (* Input of function of index 0 *)
+ | INR (INL _) => Fail Failure (* Output of function of index 0 *)
+ | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *)
+ | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *)
+ ]}
+ *)
+fun mk_case_select_result_sum (scrut : term) (tys : (hol_type * hol_type) list)
+ (fi : int) (is_input : bool) : term =
+ (* We match over the result, then over the enumeration *)
+ mk_result_case scrut (fn x => list_mk_case_sum_select x tys fi is_input)
+
+(*
+val scrut = call
+val tys = in_out_tys
+val is_input = false
+val call = mk_case_select_result_sum call in_out_tys fi false
+*)
+
+(* TODO: move *)
+fun enumerate (ls : 'a list) : (int * 'a) list =
+ zip (List.tabulate (List.length ls, fn i => i)) ls
+
+(* Generate a body for the fixed-point operator from a quoted group of mutually
+ recursive definitions.
+
+ See TODO for detailed explanations: from the quoted equations for ‘nth’
+ (or for [‘even’, ‘odd’]) we generate the body ‘nth_body’ (or ‘even_odd_body’,
+ respectively).
+ *)
+fun mk_body (fnames : string list) (in_out_tys : (hol_type * hol_type) list)
+ (def_tms : term list) : term =
+ let
+ val fnames_set = Redblackset.fromList String.compare fnames
+
+ (* Compute a map from function name to function index *)
+ val fnames_map = Redblackmap.fromList String.compare
+ (map (fn (x, y) => (y, x)) (enumerate fnames))
+
+ (* Compute the input/output type, that we dub the "parameter type" *)
+ fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
+ val param_type = sumSyntax.list_mk_sum (flatten in_out_tys)
+
+ (* Introduce a variable for the confinuation *)
+ val fcont = genvar (param_type --> mk_result param_type)
+
+ (* In the function equations, replace all the recursive calls with calls to the continuation.
+
+ When replacing a recursive call, we have to do two things:
+ - we need to inject the input parameters into the parameter type
+ Ex.:
+ - ‘nth tl i’ becomes ‘f (INL (tl, i))’ where ‘f’ is the continuation
+ - ‘even i’ becomes ‘f (INL i)’ where ‘f’ is the continuation
+ - we need to wrap the the call to the continuation into a ‘case ... of’
+ to extract its output (we need to make sure that the transformation
+ preserves the type of the expression!)
+ Ex.: ‘nth tl i’ becomes:
+ {[
+ case f (INL (tl, i)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR x => Return (INR x)
+ ]}
+ *)
+ (* For debugging *)
+ val replace_rec_calls_rec_call_tm = ref “T”
+ fun replace_rec_calls (fnames_set : string Redblackset.set) (tm : term) : term =
+ let
+ val _ = print_dbg ("replace_rec_calls: original expression:\n" ^
+ term_to_string tm ^ "\n\n")
+ val ntm =
+ case dest_term tm of
+ VAR (name, ty) =>
+ (* Check that this is not one of the functions in the group - remark:
+ we could handle that by introducing lambdas.
+ *)
+ if Redblackset.member (fnames_set, name)
+ then failwith ("mk_body: not well-formed definition: found " ^ name ^
+ " in an improper position")
+ else tm
+ | CONST _ => tm
+ | LAMB (x, tm) =>
+ let
+ (* The variable might shadow one of the functions *)
+ val fnames_set = Redblackset.delete (fnames_set, (fst o dest_var) x)
+ (* Update the term in the lambda *)
+ val tm = replace_rec_calls fnames_set tm
+ in
+ (* Reconstruct *)
+ mk_abs (x, tm)
+ end
+ | COMB (_, _) =>
+ let
+ (* Completely destruct the application, check if this is a recursive call *)
+ val (app, args) = strip_comb tm
+ val is_rec_call = Redblackset.member (fnames_set, (fst o dest_var) app)
+ handle HOL_ERR _ => false
+ (* Whatever the case, apply the transformation to all the inputs *)
+ val args = map (replace_rec_calls fnames_set) args
+ in
+ (* If this is not a recursive call: apply the transformation to all the
+ terms. Otherwise, replace. *)
+ if not is_rec_call then list_mk_comb (replace_rec_calls fnames_set app, args)
+ else
+ (* Rec call: replace *)
+ let
+ val _ = replace_rec_calls_rec_call_tm := tm
+ (* First, find the index of the function *)
+ val fname = (fst o dest_var) app
+ val fi = Redblackmap.find (fnames_map, fname)
+ (* Inject the input values into the param type *)
+ val input = pairSyntax.list_mk_pair args
+ val input = inject_in_param_sum in_out_tys fi true input
+ (* Create the recursive call *)
+ val call = mk_comb (fcont, input)
+ (* Wrap the call into a ‘case ... of’ to extract the output *)
+ val call = mk_case_select_result_sum call in_out_tys fi false
+ in
+ (* Return *)
+ call
+ end
+ end
+ val _ = print_dbg ("replace_rec_calls: new expression:\n" ^ term_to_string ntm ^ "\n\n")
+ in
+ ntm
+ end
+ handle HOL_ERR e =>
+ let
+ val _ = print_dbg ("replace_rec_calls: failed on:\n" ^ term_to_string tm ^ "\n\n")
+ in
+ raise (HOL_ERR e)
+ end
+ fun replace_rec_calls_in_eq (eq : term) : term =
+ let
+ val (l, r) = dest_eq eq
+ in
+ mk_eq (l, replace_rec_calls fnames_set r)
+ end
+ val def_tms_with_fcont = map replace_rec_calls_in_eq def_tms
+
+ (* Wrap all the function bodies to inject their result into the param type.
+
+ We collect the function inputs at the same time, because they will be
+ grouped into a tuple that we will have to deconstruct.
+ *)
+ fun inject_body_to_enums (i : int, def_eq : term) : term list * term =
+ let
+ val (l, body) = dest_eq def_eq
+ val (_, args) = strip_comb l
+ (* We have the deconstruct the result, then, in the ‘Return’ branch,
+ properly wrap the returned value *)
+ val body = mk_result_case body (fn x => mk_return (inject_in_param_sum in_out_tys i false x))
+ in
+ (args, body)
+ end
+ val def_tms_inject = map inject_body_to_enums (enumerate def_tms_with_fcont)
+
+ (* Currify the body inputs.
+
+ For instance, if the body has inputs: ‘x’, ‘y’; we return the following:
+ {[
+ (‘z’, ‘case z of (x, y) => ... (* body *) ’)
+ ]}
+ where ‘z’ is fresh.
+
+ We return: (curried input, body).
+
+ (* Debug *)
+ val body = “(x:'a, y:'b, z:'c)”
+ val args = [“x:'a”, “y:'b”, “z:'c”]
+ currify_body_inputs (args, body)
+ *)
+ fun currify_body_inputs (args : term list, body : term) : term * term =
+ let
+ fun mk_curry (args : term list) (body : term) : term * term =
+ case args of
+ [] => failwith "no inputs"
+ | [x] => (x, body)
+ | x1 :: args =>
+ let
+ val (x2, body) = mk_curry args body
+ val scrut = genvar (pairSyntax.list_mk_prod (map type_of (x1 :: args)))
+ val pat = pairSyntax.mk_pair (x1, x2)
+ val br = body
+ in
+ (scrut, TypeBase.mk_case (scrut, [(pat, br)]))
+ end
+ in
+ mk_curry args body
+ end
+ val def_tms_currified = map currify_body_inputs def_tms_inject
+
+ (* Group all the functions into a single body, with an outer ‘case .. of’
+ which selects the appropriate body depending on the input *)
+ val param_ty = sumSyntax.list_mk_sum (flatten in_out_tys)
+ val input = genvar param_ty
+ fun mk_mut_rec_body_branch (i : int) (patvar : term) : term =
+ (* Case disjunction on whether the branch is for an input value (in
+ which case we call the proper body) or an output value (in which
+ case we return ‘Fail ...’ *)
+ if i mod 2 = 0 then
+ let
+ val fi = i div 2
+ val (x, def_tm) = List.nth (def_tms_currified, fi)
+ (* The variable in the pattern and the variable expected by the
+ body may not be the same: we introduce a let binding *)
+ val def_tm = mk_let (mk_abs (x, def_tm), patvar)
+ in
+ def_tm
+ end
+ else
+ (* Output value: fail *)
+ mk_fail_failure param_ty
+ val mut_rec_body = list_mk_sum_case input (flatten in_out_tys) mk_mut_rec_body_branch
+
+
+ (* Abstract away the parameters to produce the final body of the fixed point *)
+ val mut_rec_body = list_mk_abs ([fcont, input], mut_rec_body)
+ in
+ mut_rec_body
+ end
+
+(* For explanations about the different steps, see TODO *)
+fun DefineDiv (def_qt : term quotation) =
+ let
+ (* Parse the definitions.
+
+ Example:
+ {[
+ (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\
+ (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1))
+ ]}
+ *)
+ val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
+
+ (* Compute the names and the input/output types of the functions *)
+ fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) =
+ let
+ val app = lhs tm
+ val name = (fst o dest_var o fst o strip_comb) app
+ val out_ty = dest_result (type_of app)
+ val input_tys = pairSyntax.list_mk_prod (map type_of ((snd o strip_comb) app))
+ in
+ (name, (input_tys, out_ty))
+ end
+ val (fnames, in_out_tys) = unzip (map compute_names_in_out_tys def_tms)
+
+ (* Generate the body.
+
+ See the comments at the beginning of the file (lookup "BODY GENERATION").
+ *)
+ val body = mk_body fnames in_out_tys def_tms
+
+ (* Prove that the body satisfies the validity property required by the fixed point *)
+ val body_is_valid = prove_body_is_valid body
+
+ (* Generate the definitions for the various functions by using the fixed point
+ and the body. *)
+ val raw_defs = mk_raw_defs in_out_tys def_tms body_is_valid
+
+ (* Prove the final equations *)
+ val def_eqs = prove_def_eqs body_is_valid def_tms raw_defs
+ in
+ def_eqs
+ end
+
+val [even_def, odd_def] = DefineDiv ‘
+ (even (i : int) : bool result =
+ if i = 0 then Return T else odd (i - 1)) /\
+ (odd (i : int) : bool result =
+ if i = 0 then Return F else even (i - 1))
+’
+
+val [nth_def] = DefineDiv ‘
+ nth (ls : 't list_t) (i : u32) : 't result =
+ case ls of
+ | ListCons x tl =>
+ if u32_to_int i = (0:int)
+ then (Return x)
+ else
+ do
+ i0 <- u32_sub i (int_to_u32 1);
+ nth tl i0
+ od
+ | ListNil => Fail Failure
+’
+
+val even_odd_body_def = Define ‘
+ even_odd_body
+ (* The body takes a continuation - required by the fixed-point operator *)
+ (f : (int + bool + int + bool) -> (int + bool + int + bool) result)
+ (* The type of the input/output is:
+ input of even + output of even + input of odd + output of odd
+ *)
+ (x : int + bool + int + bool) :
+ (int + bool + int + bool) result =
+ (* Case disjunction over the input, in order to figure out which
+ function from the group is actually called (even , or odd). *)
+ case x of
+ | INL i => (* Input of even *)
+ (* Body of even *)
+ if i = 0 then Return (INR (INL T))
+ else
+ (* Recursive calls are calls to the continuation f, wrapped
+ in the proper variant: here we call odd *)
+ (case f (INR (INR (INL (i - 1)))) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ (* Eliminate the unwanted results *)
+ case r of
+ | INL _ => Fail Failure
+ | INR (INL _) => Fail Failure
+ | INR (INR (INL _)) => Fail Failure
+ | INR (INR (INR b)) => (* Extract the output of odd *)
+ (* Inject into the variant for the output of even *)
+ Return (INR (INL b))
+ )
+ | INR (INL _) => (* Output of even *)
+ (* We must ignore this one *)
+ Fail Failure
+ | INR (INR (INL i)) =>
+ (* Body of odd *)
+ if i = 0 then Return (INR (INR (INR F)))
+ else
+ (* Call to even *)
+ (case f (INL (i - 1)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ (* Eliminate the unwanted results *)
+ case r of
+ | INL _ => Fail Failure
+ | INR (INL b) => (* Extract the output of even *)
+ (* Inject into the variant for the output of odd *)
+ Return (INR (INR (INR b)))
+ | INR (INR (INL _)) => Fail Failure
+ | INR (INR (INR _)) => Fail Failure
+ )
+ | INR (INR (INR _)) => (* Output of odd *)
+ (* We must ignore this one *)
+ Fail Failure
+’
+
+Theorem even_odd_body_is_valid_aux:
+ is_valid_fp_body (SUC (SUC n)) even_odd_body
+Proof
+ prove_body_is_valid_tac (SOME even_odd_body_def)
+QED
+
+Theorem even_odd_body_is_valid:
+ is_valid_fp_body (SUC (SUC 0)) even_odd_body
+Proof
+ irule even_odd_body_is_valid_aux
+QED
+
+val even_raw_def = Define ‘
+ even (i : int) =
+ case fix even_odd_body (INL i) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR (INL b) => Return b
+ | INR (INR (INL _)) => Fail Failure
+ | INR (INR (INR _)) => Fail Failure
+’
+
+val odd_raw_def = Define ‘
+ odd (i : int) =
+ case fix even_odd_body (INR (INR (INL i))) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR (INL b) => Fail Failure
+ | INR (INR (INL _)) => Fail Failure
+ | INR (INR (INR b)) => Return b
+’
+
+Theorem even_def:
+ ∀i. even (i : int) : bool result =
+ if i = 0 then Return T else odd (i - 1)
+Proof
+ prove_def_eq_tac even_raw_def [even_raw_def, odd_raw_def] even_odd_body_is_valid even_odd_body_def
+QED
+
+Theorem odd_def:
+ ∀i. odd (i : int) : bool result =
+ if i = 0 then Return F else even (i - 1)
+Proof
+ prove_def_eq_tac odd_raw_def [even_raw_def, odd_raw_def] even_odd_body_is_valid even_odd_body_def
+QED
+
+val _ = export_theory ()
diff --git a/backends/hol4/primitivesBaseTacLib.sml b/backends/hol4/primitivesBaseTacLib.sml
index fe87e894..78822abe 100644
--- a/backends/hol4/primitivesBaseTacLib.sml
+++ b/backends/hol4/primitivesBaseTacLib.sml
@@ -47,11 +47,14 @@ val pat_undisch_tac = Q.PAT_UNDISCH_TAC
val equiv_is_imp = prove (“∀x y. ((x ⇒ y) ∧ (y ⇒ x)) ⇒ (x ⇔ y)”, metis_tac [])
val equiv_tac = irule equiv_is_imp >> conj_tac
+(* Rewrite the goal once, and on the left part of the goal seen as an application *)
+fun pure_once_rewrite_left_tac ths =
+ CONV_TAC (PATH_CONV "l" (PURE_ONCE_REWRITE_CONV ths))
+
(* Dependent rewrites *)
val dep_pure_once_rewrite_tac = dep_rewrite.DEP_PURE_ONCE_REWRITE_TAC
val dep_pure_rewrite_tac = dep_rewrite.DEP_PURE_REWRITE_TAC
-
(* Add a list of theorems in the assumptions *)
fun assume_tacl (thms : thm list) : tactic =
let