summaryrefslogtreecommitdiff
path: root/backends/hol4/divDefLib.sml
diff options
context:
space:
mode:
Diffstat (limited to 'backends/hol4/divDefLib.sml')
-rw-r--r--backends/hol4/divDefLib.sml1789
1 files changed, 740 insertions, 1049 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml
index 59c1edaf..3e2d7c04 100644
--- a/backends/hol4/divDefLib.sml
+++ b/backends/hol4/divDefLib.sml
@@ -1,5 +1,3 @@
-(* This file implements utilities to define potentially diverging functions *)
-
structure divDefLib :> divDefLib =
struct
@@ -8,1196 +6,889 @@ open boolTheory arithmeticTheory integerTheory intLib listTheory stringTheory
open primitivesArithTheory primitivesBaseTacLib ilistTheory primitivesTheory
open primitivesLib
+open divDefTheory
-val case_result_same_eq = prove (
- “!(r : 'a result).
- (case r of
- Return x => Return x
- | Fail e => Fail e
- | Diverge => Diverge) = r”,
- rw [] >> CASE_TAC)
-
-(*
-val ty = id_ty
-strip_arrows ty
-*)
+val dbg = ref false
+fun print_dbg s = if (!dbg) then print s else ()
-(* TODO: move *)
-fun list_mk_arrow (tys : hol_type list) (ret_ty : hol_type) : hol_type =
- foldr (fn (ty, aty) => ty --> aty) ret_ty tys
+val result_ty = “:'a result”
+val error_ty = “:error”
+val alpha_ty = “:'a”
+val num_ty = “:num”
-(* TODO: move *)
-fun strip_arrows (ty : hol_type) : hol_type list * hol_type =
- let
- val (ty0, ty1) = dom_rng ty
- val (tys, ret) = strip_arrows ty1
- in
- (ty0::tys, ret)
- end
- handle HOL_ERR _ => ([], ty)
+val zero_num_tm = “0:num”
+val suc_tm = “SUC”
-(* Small utilities *)
-val current_goal : term option ref = ref NONE
+val return_tm = “Return : 'a -> 'a result”
+val fail_tm = “Fail : error -> 'a result”
+val fail_failure_tm = “Fail Failure : 'a result”
+val diverge_tm = “Diverge : 'a result”
-(* Save a goal in {!current_goal} then prove it.
+val fix_tm = “fix”
+val is_valid_fp_body_tm = “is_valid_fp_body”
- This way if the proof fails we can easily retrieve the goal for debugging
- purposes.
- *)
-fun save_goal_and_prove (g, tac) : thm =
+fun mk_result (ty : hol_type) : hol_type = Type.type_subst [ alpha_ty |-> ty ] result_ty
+fun dest_result (ty : hol_type) : hol_type =
let
- val _ = current_goal := SOME g
+ val {Args=out_ty, Thy=thy, Tyop=tyop} = dest_thy_type ty
in
- prove (g, tac)
+ if thy = "primitives" andalso tyop = "result" then hd out_ty
+ else failwith "dest_result: not a result"
end
-
-
-(*val def_qt = ‘
-(nth_fuel (n : num) (ls : 't list_t) (i : u32) : 't result =
- case n of
- | 0 => Loop
- | SUC n =>
- do case ls of
- | ListCons x tl =>
- if u32_to_int i = (0:int)
- then Return x
- else
- do
- i0 <- u32_sub i (int_to_u32 1);
- nth_fuel n tl i0
- od
- | ListNil =>
- Fail Failure
- od)
-’*)
-
-val num_zero_tm = “0:num”
-val num_suc_tm = “SUC: num -> num”
-val num_ty = “:num”
-
-val fuel_def_suffix = "___fuel" (* TODO: name collisions *)
-val fuel_var_name = "$n" (* TODO: name collisions *)
-val fuel_var = mk_var (fuel_var_name, num_ty)
-val fuel_var0 = fuel_var
-val fuel_var1 = mk_var ("$m", “:num”) (* TODO: name collisions *)
-val fuel_vars_le = “^fuel_var0 <= ^fuel_var1”
-val fuel_predicate_suffix = "___P" (* TODO: name collisions *)
-val expand_suffix = "___E" (* TODO: name collisions *)
+fun mk_return (x : term) : term = mk_icomb (return_tm, x)
+fun mk_fail (ty : hol_type) (e : term) : term = mk_comb (inst [ alpha_ty |-> ty ] fail_tm, e)
+fun mk_fail_failure (ty : hol_type) : term = inst [ alpha_ty |-> ty ] fail_failure_tm
+fun mk_diverge (ty : hol_type) : term = inst [ alpha_ty |-> ty ] diverge_tm
-val bool_ty = “:bool”
+fun mk_suc (n : term) = mk_comb (suc_tm, n)
-val alpha_tyvar : hol_type = “:'a”
-val beta_tyvar : hol_type = “:'b”
-
-val is_diverge_tm = “is_diverge: 'a result -> bool”
-val diverge_tm = “Diverge : 'a result”
+fun enumerate (ls : 'a list) : (int * 'a) list =
+ zip (List.tabulate (List.length ls, fn i => i)) ls
-val least_tm = “$LEAST”
-val le_tm = (fst o strip_comb) “x:num <= y:num”
-val true_tm = “T”
-val false_tm = “F”
+(*=============================================================================*
+ *
+ * Generate the (non-recursive) body to give to the fixed-point operator
+ *
+ * ============================================================================*)
-val measure_tm = “measure: ('a -> num) -> 'a -> 'a -> bool”
-
-fun mk_diverge_tm (ty : hol_type) : term =
- let
- val diverge_ty = mk_thy_type {Thy="primitives", Tyop="result", Args = [ty] }
- val diverge_tm = mk_thy_const { Thy="primitives", Name="Diverge", Ty=diverge_ty }
- in
- diverge_tm
- end
+(* Small helper to generate wrappers of the shape: ‘INL x’, ‘INR (INL x)’, etc.
+ Note that we should have: ‘length before_tys + 1 + length after tys >= 2’
-(* Small utility: we sometimes need to generate a termination measure for
- the fuel definitions.
+ Ex.:
+ ====
+ The enumeration has type: “: 'a + 'b + 'c + 'd”.
+ We want to generate the variant which injects “x:'c” into this enumeration.
- We derive a measure for a type which is simply the sum of the tuples
- of the input types of the functions.
-
- For instance, for even and odd we have:
+ We need to split the list of types into:
{[
- even___fuel : num -> int -> bool result
- odd___fuel : num -> int -> bool result
+ before_tys = [“:'a”, “'b”]
+ tm = “x: 'c”
+ after_tys = [“:'d”]
]}
- So the type would be:
+ The function will generate:
{[
- (num # int) + (num # int)
+ INR (INR (INL x) : 'a + 'b + 'c + 'd
]}
- Note that generally speaking we expect a type of the shape (the “:num”
- on the left is for the fuel):
- {[
- (num # ...) + (num # ...) + ... + (num # ...)
- ]}
+ (* Debug *)
+ val before_tys = [“:'a”, “:'b”, “:'c”]
+ val tm = “x:'d”
+ val after_tys = [“:'e”, “:'f”]
- The decreasing measure is simply given by a function which matches over
- its argument to return the fuel, whatever the case.
+ val before_tys = [“:'a”, “:'b”, “:'c”]
+ val tm = “x:'d”
+ val after_tys = []
+
+ mk_inl_inr_wrapper before_tys tm after_tys
*)
-fun mk_termination_measure_from_ty (ty : hol_type) : term =
+fun list_mk_inl_inr (before_tys : hol_type list) (tm : term) (after_tys : hol_type list) :
+ term =
let
- val dtys = map pairSyntax.strip_prod (sumSyntax.strip_sum ty)
- (* For every tuple, create a match to extract the num *)
- fun mk_case_of_tuple (tys : hol_type list) : (term * term) =
- case tys of
- [] => failwith "mk_termination_measure_from_ty: empty list of types"
- | [num_ty] =>
- (* No need for a case *)
- let val var = genvar num_ty in (var, var) end
- | num_ty :: rem_tys =>
+ val (before_tys, pat) =
+ if after_tys = []
+ then
let
- val scrut_var = genvar (pairSyntax.list_mk_prod tys)
- val var = genvar num_ty
- val rem_var = genvar (pairSyntax.list_mk_prod rem_tys)
- val pats = [(pairSyntax.mk_pair (var, rem_var), var)]
- val case_tm = TypeBase.mk_case (scrut_var, pats)
+ val just_before_ty = List.last before_tys
+ val before_tys = List.take (before_tys, List.length before_tys - 1)
+ val pat = sumSyntax.mk_inr (tm, just_before_ty)
in
- (scrut_var, case_tm)
+ (before_tys, pat)
end
- val tuple_cases = map mk_case_of_tuple dtys
-
- (* For every sum, create a match to extract one of the tuples *)
- fun mk_sum_case ((tuple_var, tuple_case), (nvar, case_end)) =
- let
- val left_pat = sumSyntax.mk_inl (tuple_var, type_of nvar)
- val right_pat = sumSyntax.mk_inr (nvar, type_of tuple_var)
- val scrut = genvar (sumSyntax.mk_sum (type_of tuple_var, type_of nvar))
- val pats = [(left_pat, tuple_case), (right_pat, case_end)]
- val case_tm = TypeBase.mk_case (scrut, pats)
- in
- (scrut, case_tm)
- end
- val tuple_cases = rev tuple_cases
- val (nvar, case_end) = hd tuple_cases
- val tuple_cases = tl tuple_cases
- val (scrut, case_tm) = foldl mk_sum_case (nvar, case_end) tuple_cases
-
- (* Create the function *)
- val abs_tm = mk_abs (scrut, case_tm)
-
- (* Add the “measure term” *)
- val tm = inst [alpha_tyvar |-> type_of scrut] measure_tm
- val tm = mk_comb (tm, abs_tm)
+ else (before_tys, sumSyntax.mk_inl (tm, sumSyntax.list_mk_sum after_tys))
+ val pat = foldr (fn (ty, pat) => sumSyntax.mk_inr (pat, ty)) pat before_tys
in
- tm
+ pat
end
-(*
-val ty = “: (num # 'a) + (num # 'b) + (num # 'c)”
-
-val tys = hd dtys
-val num_ty::rem_tys = tys
-
-val (tuple_var, tuple_case) = hd tuple_cases
-*)
-
-(* Get the smallest id which make the names unique (or to be more precise:
- such that the names don't correspond to already defined constants).
-
- We do this for {!mk_fuel_defs}: for some reason, the termination proof
- fails if we try to reuse the same names as before.
+(* This function wraps a term into the proper variant of the input/output
+ sum.
+
+ Ex.:
+ ====
+ For the input of the first function, we generate: ‘INL x’
+ For the output of the first function, we generate: ‘INR (INL x)’
+ For the input of the 2nd function, we generate: ‘INR (INR (INL x))’
+ etc.
+
+ If ‘is_input’ is true: we are wrapping an input. Otherwise we are wrapping
+ an output.
+
+ (* Debug *)
+ val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”), (“:'e”, “:'f”)]
+ val j = 1
+ val tm = “x:'c”
+ val tm = “y:'d”
+ val is_input = true
*)
-fun get_smallest_unique_id_for_names (names : string list) : string =
+fun inject_in_param_sum (tys : (hol_type * hol_type) list) (j : int) (is_input : bool)
+ (tm : term) : term =
let
- (* Not trying to be smart here *)
- val i : int option ref = ref NONE
- fun get_i () = case !i of NONE => "" | SOME i => int_to_string i
- fun incr_i () =
- i := (case !i of NONE => SOME 0 | SOME i => SOME (i+1))
- val continue = ref true
- fun name_is_ok (name : string) : bool =
- not (is_const (Parse.parse_in_context [] [QUOTE (name ^ get_i ())]))
- handle HOL_ERR _ => false
- val _ =
- while !continue do (
- let val _ = (continue := not (forall name_is_ok names)) in
- if !continue then incr_i () else () end
- )
+ fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
+ val before_tys = flatten (List.take (tys, j))
+ val (input_ty, output_ty) = List.nth (tys, j)
+ val after_tys = flatten (List.drop (tys, j + 1))
+ val (before_tys, after_tys) =
+ if is_input then (before_tys, output_ty :: after_tys)
+ else (before_tys @ [input_ty], after_tys)
in
- get_i ()
+ list_mk_inl_inr before_tys tm after_tys
end
-fun mk_fuel_defs (def_tms : term list) : thm list =
- let
- (* Retrieve the identifiers.
+(* Remark: the order of the branches when creating matches is important.
+ For instance, in the case of ‘result’ it must be: ‘Return’, ‘Fail’, ‘Diverge’.
- Ex.: def_tm = “even (n : int) : bool result = if i = 0 then Return T else odd (i - 1))”
- We want to retrive: id = “even”
- *)
- val ids = map (fst o strip_comb o lhs) def_tms
-
- (* In the definitions, replace the identifiers by new identifiers which use
- fuel.
-
- Ex.:
- def_fuel_tm = “
- even___fuel (fuel : nat) (n : int) : result bool =
- case fuel of 0 => Diverge
- | SUC fuel' =>
- if i = 0 then Return T else odd_fuel fuel' (i - 1))”
- *)
- val names = map ((fn s => s ^ fuel_def_suffix) o fst o dest_var) ids
- val index = get_smallest_unique_id_for_names names
- fun mk_fuel_id (id : term) : term =
- let
- val (id_str, ty) = dest_var id
- (* Note: we use symbols forbidden in the generation of code to
- prevent name collisions *)
- val fuel_id_str = id_str ^ fuel_def_suffix ^ index
- val fuel_id = mk_var (fuel_id_str, num_ty --> ty)
- in fuel_id end
- val fuel_ids = map mk_fuel_id ids
-
- val fuel_ids_with_fuel0 = map (fn id => mk_comb (id, fuel_var0)) fuel_ids
- val fuel_ids_with_fuel1 = map (fn id => mk_comb (id, fuel_var1)) fuel_ids
-
- (* Recurse through the terms and replace the calls *)
- val rwr_thms0 = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel0)
- val rwr_thms1 = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel1)
-
- fun mk_fuel_tm (def_tm : term) : term =
- let
- val (tm0, tm1) = dest_eq def_tm
- val tm0 = (rhs o concl o (PURE_REWRITE_CONV rwr_thms0)) tm0
- val tm1 = (rhs o concl o (PURE_REWRITE_CONV rwr_thms1)) tm1
- in mk_eq (tm0, tm1) end
- val fuel_tms = map mk_fuel_tm def_tms
-
- (* Add the case over the fuel *)
- fun add_fuel_case (tm : term) : term =
- let
- val (f, body) = dest_eq tm
- (* Create the “Diverge” term with the proper type *)
- val body_ty = type_of body
- val return_ty =
- case (snd o dest_type) body_ty of [ty] => ty
- | _ => failwith "unexpected"
- val diverge_tm = mk_diverge_tm return_ty
- (* Create the “SUC fuel” term *)
- val suc_tm = mk_comb (num_suc_tm, fuel_var1)
- val fuel_tm =
- TypeBase.mk_case (fuel_var0, [(num_zero_tm, diverge_tm), (suc_tm, body)])
- in mk_eq (f, fuel_tm) end
- val fuel_tms = map add_fuel_case fuel_tms
-
- (* Define the auxiliary definitions which use fuel *)
- val fuel_defs_conj = list_mk_conj fuel_tms
- (* The definition name *)
- val def_name = (fst o dest_var o hd) fuel_ids
- (* The tactic to prove the termination *)
- val rty = ref “:bool” (* This is useful for debugging *)
- fun prove_termination_tac (asms, g) =
- let
- val r_tm = (fst o dest_exists) g
- val _ = rty := type_of r_tm
- val ty = (hd o snd o dest_type) (!rty)
- val m_tm = mk_termination_measure_from_ty ty
- in
- WF_REL_TAC ‘^m_tm’ (asms, g)
- end
-
- (* Define the fuel definitions *)
- (*
- val temp_def = Hol_defn def_name ‘^fuel_defs_conj’
- Defn.tgoal temp_def
- *)
- val fuel_defs = tDefine def_name ‘^fuel_defs_conj’ prove_termination_tac
- in
- CONJUNCTS fuel_defs
- end
-
-(*
-val (fuel_tms, fuel_defs) = mk_fuel_defs def_tms
-val fuel_def_tms = map (snd o strip_forall) ((strip_conj o concl) fuel_defs)
-val (def_tm, fuel_def_tm) = hd (zip def_tms fuel_def_tms)
-*)
-
-fun mk_is_diverge_tm (fuel_tm : term) : term =
- case snd (dest_type (type_of fuel_tm)) of
- [ret_ty] => mk_comb (inst [alpha_tyvar |-> ret_ty] is_diverge_tm, fuel_tm)
- | _ => failwith "mk_is_diverge_tm: unexpected"
-
-fun mk_fuel_predicate_defs (def_tm, fuel_def_tm) : thm =
+ For the purpose of stability and maintainability, we introduce this small helper
+ which reorders the cases in a pattern before actually creating the case
+ expression.
+ *)
+fun unordered_mk_case (scrut: term, pats: (term * term) list) : term =
let
- (* From [even i] create the term [even_P i n], where [n] is the fuel *)
- val (id, args) = (strip_comb o lhs) def_tm
- val (id_str, id_ty) = dest_var id
- val (tys, ret_ty) = strip_arrows id_ty
- val tys = append tys [num_ty]
- val pred_ty = list_mk_arrow tys bool_ty
- val pred_id = mk_var (id_str ^ fuel_predicate_suffix, pred_ty)
- val pred_tm = list_mk_comb (pred_id, append args [fuel_var])
-
- (* Create the term ~is_diverge (even_fuel n i) *)
- val fuel_tm = lhs fuel_def_tm
- val not_is_diverge_tm = mk_neg (mk_is_diverge_tm fuel_tm)
-
- (* Create the term: even_P i n = ~(is_diverge (even_fuel n i) *)
- val pred_def_tm = mk_eq (pred_tm, not_is_diverge_tm)
+ (* Retrieve the constructors *)
+ val cl = TypeBase.constructors_of (type_of scrut)
+ (* Retrieve the names of the constructors *)
+ val names = map (fst o dest_const) cl
+ (* Use those to reorder the patterns *)
+ fun is_pat (name : string) (pat, _) =
+ let
+ val app = (fst o strip_comb) pat
+ val app_name = (fst o dest_const) app
+ in
+ app_name = name
+ end
+ val pats = map (fn name => valOf (List.find (is_pat name) pats)) names
in
- (* Create the definition *)
- Define ‘^pred_def_tm’
+ (* Create the case *)
+ TypeBase.mk_case (scrut, pats)
end
-(*
-val (def_tm, fuel_def_tm) = hd (zip def_tms fuel_def_tms)
-
-val pred_defs = map mk_fuel_predicate_defs (zip def_tms fuel_def_tms)
-*)
+(* Wrap a term of type “:'a result” into a ‘case of’ which matches over
+ the result.
-(* Tactic which makes progress in a proof by making a case disjunction (we use
- this to explore all the paths in a function body). *)
-fun case_progress (asms, g) =
- let
- val scrut = (strip_all_cases_get_scrutinee o lhs) g
- in Cases_on ‘^scrut’ (asms, g) end
+ Ex.:
+ ====
+ {[
+ f x
-(* Prove the fuel monotonicity properties.
+ ~~>
- We want to prove a theorem of the shape:
- {[
- !n m.
- (!i. n <= m ==> even___P i n ==> even___fuel n i = even___fuel m i) /\
- (!i. n <= m ==> odd___P i n ==> odd___fuel n i = odd___fuel m i)
+ case f x of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return y => ... (* The branch content is generated by the continuation *)
]}
-*)
-fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm =
- let
- val pred_tms = map (lhs o snd o strip_forall o concl) pred_defs
- val fuel_tms = map (lhs o snd o strip_forall o concl) fuel_defs
- val pred_fuel_tms = zip pred_tms fuel_tms
- (* Create a set containing the names of all the functions in the recursive group *)
- val rec_fun_set =
- Redblackset.fromList const_name_compare (map get_fun_name_from_app fuel_tms)
- (* Small tactic which rewrites the occurrences of recursive calls *)
- fun rewrite_rec_call (asms, g) =
- let
- val scrut = (strip_all_cases_get_scrutinee o lhs) g
- val fun_id = get_fun_name_from_app scrut (* This can fail *)
- in
- (* Check if the function is part of the group we are considering *)
- if Redblackset.member (rec_fun_set, fun_id) then
- let
- (* Yes: use the induction hypothesis *)
- fun apply_ind_hyp (ind_th : thm) : tactic =
- let
- val th = SPEC_ALL ind_th
- val th_pat = (lhs o snd o strip_imp o concl) th
- val (var_s, ty_s) = match_term th_pat scrut
- (* Note that in practice the type instantiation should be empty *)
- val th = INST var_s (INST_TYPE ty_s th)
- in
- assume_tac th
- end
- in
- (last_assum apply_ind_hyp >> fs []) (asms, g)
- end
- else all_tac (asms, g)
- end
- handle HOL_ERR _ => all_tac (asms, g)
- (* Generate terms of the shape:
- !i. n <= m ==> even___P i n ==> even___fuel n i = even___fuel m i
- *)
- fun mk_fuel_eq_tm (pred_tm, fuel_tm) : term =
- let
- (* Retrieve the variables which are not the fuel - for the quantifiers *)
- val vars = (tl o snd o strip_comb) fuel_tm
- (* Introduce the fuel term which uses “m” *)
- val m_fuel_tm = subst [fuel_var0 |-> fuel_var1] fuel_tm
- (* Introduce the equality *)
- val fuel_eq_tm = mk_eq (fuel_tm, m_fuel_tm)
- (* Introduce the implication with the _P pred *)
- val fuel_eq_tm = mk_imp (pred_tm, fuel_eq_tm)
- (* Introduce the “n <= m ==> ...” implication *)
- val fuel_eq_tm = mk_imp (fuel_vars_le, fuel_eq_tm)
- (* Quantify *)
- val fuel_eq_tm = list_mk_forall (vars, fuel_eq_tm)
- in
- fuel_eq_tm
- end
- val fuel_eq_tms = map mk_fuel_eq_tm pred_fuel_tms
- (* Create the conjunction *)
- val fuel_eq_tms = list_mk_conj fuel_eq_tms
- (* Qantify over the fuels *)
- val fuel_eq_tms = list_mk_forall ([fuel_var0, fuel_var1], fuel_eq_tms)
- (* The tactics for the proof *)
- val prove_tac =
- Induct_on ‘^fuel_var0’ >-(
- (* The ___P predicates are false: n is 0 *)
- fs pred_defs >>
- fs [is_diverge_def] >>
- pure_once_rewrite_tac fuel_defs >> fs []) >>
- (* Introduce n *)
- gen_tac >>
- (* Introduce m *)
- Cases_on ‘^fuel_var1’ >-(
- (* Contradiction: SUC n < 0 *)
- rw [] >> exfalso >> int_tac) >>
- fs pred_defs >>
- fs [is_diverge_def] >>
- pure_once_rewrite_tac fuel_defs >> fs [bind_def] >>
- (* Introduce in the context *)
- rpt gen_tac >>
- (* Split the goals - note that we prove one big goal for all the functions at once *)
- rpt strip_tac >>
- (* Instantiate the assumption: !m. n <= m ==> ~(...)
- with the proper m.
- *)
- last_x_assum imp_res_tac >>
- (* Make sure the induction hypothesis is always the last assumption *)
- last_x_assum assume_tac >>
- (* Split the goals *)
- rpt strip_tac >> fs [case_result_same_eq] >>
- (* Explore all the paths *)
- rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq])
- in
- (* Prove *)
- save_goal_and_prove (fuel_eq_tms, prove_tac)
- end
-(*
-val fuel_mono_thm = prove_fuel_mono pred_defs fuel_defs
+ ‘gen_ret_branch’ is a *continuation* which generates the content of the
+ ‘Return’ branch (i.e., the content of the ‘...’ in the example above).
+ It receives as input the value contained by the ‘Return’ (i.e., the variable
+ ‘y’ in the example above).
-set_goal ([], fuel_eq_tms)
-*)
+ Remark.: the type of the term generated by ‘gen_ret_branch’ must have
+ the type ‘result’, but it can change the content of the result (i.e.,
+ if ‘scrut’ has type ‘:'a result’, we can change the type of the wrapped
+ expression to ‘:'b result’).
-(* Prove the property about the least upper bound.
+ (* Debug *)
+ val scrut = “x: int result”
+ fun gen_ret_branch x = mk_return x
- We want to prove theorems of the shape:
- {[
- (!n i. $LEAST (even___P i) <= n ==> even___fuel n i = even___fuel ($LEAST (even___P i)) i)
- ]}
- {[
- (!n i. $LEAST (odd___P i) <= n ==> odd___fuel n i = odd___fuel ($LEAST (odd___P i)) i)
- ]}
+ val scrut = “x: int result”
+ fun gen_ret_branch _ = “Return T”
- TODO: merge with other functions? (prove_pred_imp_fuel_eq_raw_thms)
-*)
-fun prove_least_fuel_mono (pred_defs : thm list) (fuel_mono_thm : thm) : thm list =
+ mk_result_case scrut gen_ret_branch
+ *)
+fun mk_result_case (scrut : term) (gen_ret_branch : term -> term) : term =
let
- val thl = (CONJUNCTS o SPECL [fuel_var0, fuel_var1]) fuel_mono_thm
- fun mk_least_fuel_thm (pred_def, mono_thm) : thm =
- let
- (* Retrieve the predicate, without the fuel *)
- val pred_tm = (lhs o snd o strip_forall o concl) pred_def
- val (pred_tm, args) = strip_comb pred_tm
- val args = rev (tl (rev args))
- val pred_tm = list_mk_comb (pred_tm, args)
- (* Add $LEAST *)
- val least_pred_tm = mk_comb (least_tm, pred_tm)
- (* Specialize all *)
- val vars = (fst o strip_forall o concl) mono_thm
- val th = SPECL vars mono_thm
- (* Substitute in the mono theorem *)
- val th = INST [fuel_var0 |-> least_pred_tm] th
- (* Symmetrize the equality *)
- val th = PURE_ONCE_REWRITE_RULE [EQ_SYM_EQ] th
- (* Quantify *)
- val th = GENL (fuel_var1 :: vars) th
- in
- th
- end
+ val scrut_ty = dest_result (type_of scrut)
+ (* Return branch *)
+ val ret_var = genvar scrut_ty
+ val ret_pat = mk_return ret_var
+ val ret_br = gen_ret_branch ret_var
+ val ret_ty = dest_result (type_of ret_br)
+ (* Failure branch *)
+ val fail_var = genvar error_ty
+ val fail_pat = mk_fail scrut_ty fail_var
+ val fail_br = mk_fail ret_ty fail_var
+ (* Diverge branch *)
+ val div_pat = mk_diverge scrut_ty
+ val div_br = mk_diverge ret_ty
in
- map mk_least_fuel_thm (zip pred_defs thl)
+ unordered_mk_case (scrut, [(ret_pat, ret_br), (fail_pat, fail_br), (div_pat, div_br)])
end
-(*
-val (pred_def, mono_thm) = hd (zip pred_defs thl)
-*)
+(* Generate a ‘case ... of’ over a sum type.
-(* Prove theorems of the shape:
+ Ex.:
+ ====
+ If the scrutinee is: “x : 'a + 'b + 'c” (i.e., the tys list is: [“:'a”, “:b”, “:c”]),
+ we generate:
{[
- !n i. even___P i n ==> $LEAST (even___P i) <= n
+ case x of
+ | INL y0 => ... (* Branch of index 0 *)
+ | INR (INL y1) => ... (* Branch of index 1 *)
+ | INR (INR (INL y2)) => ... (* Branch of index 2 *)
+ | INR (INR (INR y3)) => ... (* Branch of index 3 *)
]}
- TODO: merge with other functions? (prove_pred_imp_fuel_eq_raw_thms)
+ The content of the branches is generated by the ‘gen_branch’ continuation,
+ which receives as input the index of the branch as well as the variable
+ introduced by the pattern (in the example above: ‘y0’ for the branch 0,
+ ‘y1’ for the branch 1, etc.)
+
+ (* Debug *)
+ val tys = [“:'a”, “:'b”]
+ val scrut = mk_var ("x", sumSyntax.list_mk_sum tys)
+ fun gen_branch i (x : term) = “F”
+
+ val tys = [“:'a”, “:'b”, “:'c”, “:'d”]
+ val scrut = mk_var ("x", sumSyntax.list_mk_sum tys)
+ fun gen_branch i (x : term) = if type_of x = “:'c” then mk_return x else mk_fail_failure “:'c”
+
+ list_mk_sum_case scrut tys gen_branch
*)
-fun prove_least_pred_thms (pred_defs : thm list) : thm list =
+(* For debugging *)
+val list_mk_sum_case_case = ref (“T”, [] : (term * term) list)
+(*
+val (scrut, [(pat1, br1), (pat2, br2)]) = !list_mk_sum_case_case
+*)
+fun list_mk_sum_case (scrut : term) (tys : hol_type list)
+ (gen_branch : int -> term -> term) : term =
let
- fun prove_least_pred_thm (pred_def : thm) : thm =
+ (* Create the cases. Note that without sugar, the match actually looks like this:
+ {[
+ case x of
+ | INL y0 => ... (* Branch of index 0 *)
+ | INR x1
+ case x1 of
+ | INL y1 => ... (* Branch of index 1 *)
+ | INR x2 =>
+ case x2 of
+ | INL y2 => ... (* Branch of index 2 *)
+ | INR y3 => ... (* Branch of index 3 *)
+ ]}
+ *)
+ fun create_case (j : int) (scrut : term) (tys : hol_type list) : term =
let
- val pred_tm = (lhs o snd o strip_forall o concl) pred_def
- val (pred_no_fuel_tm, args) = strip_comb pred_tm
- val args = rev (tl (rev args))
- val pred_no_fuel_tm = list_mk_comb (pred_no_fuel_tm, args)
- (* Make the “$LEAST (even___P i)” term *)
- val least_pred_tm = mk_comb (least_tm, pred_no_fuel_tm)
- (* Make the inequality *)
- val tm = list_mk_comb (le_tm, [least_pred_tm, fuel_var0])
- (* Add the implication *)
- val tm = mk_imp (pred_tm, tm)
- (* Quantify *)
- val tm = list_mk_forall (args, tm)
- val tm = mk_forall (fuel_var0, tm)
- (* Prove *)
- val prove_tac =
- rpt gen_tac >>
- disch_tac >>
- (* Use the "fundamental" property about $LEAST *)
- qspec_assume ‘^pred_no_fuel_tm’ whileTheory.LEAST_EXISTS_IMP >>
- (* Prove the premise *)
- pop_assum sg_premise_tac >- (exists_tac fuel_var0 >> fs []) >>
- rw [] >>
- (* Finish the proof by contraposition *)
- spose_not_then assume_tac >>
- fs [not_le_eq_gt]
+ val _ = print_dbg ("list_mk_sum_case: " ^
+ String.concatWith ", " (map type_to_string tys) ^ "\n")
in
- save_goal_and_prove (tm, prove_tac)
+ case tys of
+ [] => failwith "tys is too short"
+ | [ ty ] =>
+ (* Last element: no match to perform *)
+ gen_branch j scrut
+ | ty1 :: tys =>
+ (* Not last: we create a pattern:
+ {[
+ case scrut of
+ | INL pat_var1 => ... (* Branch of index i *)
+ | INR pat_var2 =>
+ ... (* Generate this term recursively *)
+ ]}
+ *)
+ let
+ (* INL branch *)
+ val after_ty = sumSyntax.list_mk_sum tys
+ val pat_var1 = genvar ty1
+ val pat1 = sumSyntax.mk_inl (pat_var1, after_ty)
+ val br1 = gen_branch j pat_var1
+ (* INR branch *)
+ val pat_var2 = genvar after_ty
+ val pat2 = sumSyntax.mk_inr (pat_var2, ty1)
+ val br2 = create_case (j+1) pat_var2 tys
+ val _ = print_dbg ("list_mk_sum_case: assembling:\n" ^
+ term_to_string scrut ^ ",\n" ^
+ "[(" ^ term_to_string pat1 ^ ",\n " ^ term_to_string br1 ^ "),\n\n" ^
+ " (" ^ term_to_string pat2 ^ ",\n " ^ term_to_string br2 ^ ")]\n\n")
+ val case_elems = (scrut, [(pat1, br1), (pat2, br2)])
+ val _ = list_mk_sum_case_case := case_elems
+ in
+ (* Put everything together *)
+ TypeBase.mk_case case_elems
+ end
end
in
- map prove_least_pred_thm pred_defs
+ create_case 0 scrut tys
end
+(* Generate a ‘case ... of’ to select the input/output of the ith variant of
+ the param enumeration.
-(*
-val least_pred_thms = prove_least_pred_thms pred_defs
-
-val least_pred_thm = hd least_pred_thms
-*)
-
-(* Prove theorems of the shape:
-
+ Ex.:
+ ====
+ There are two functions in the group, and we select the input of the function of index 1:
{[
- !n i. even___P i n ==> even___P i ($LEAST (even___P i))
+ case x of
+ | INL _ => Fail Failure (* Input of function of index 0 *)
+ | INR (INL _) => Fail Failure (* Output of function of index 0 *)
+ | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *)
+ | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *)
]}
-*)
-fun prove_pred_n_imp_pred_least_thms (pred_defs : thm list) : thm list =
- let
- fun prove_pred_n_imp_pred_least (pred_def : thm) : thm =
- let
- val pred_tm = (lhs o snd o strip_forall o concl) pred_def
- val (pred_no_fuel_tm, args) = strip_comb pred_tm
- val args = rev (tl (rev args))
- val pred_no_fuel_tm = list_mk_comb (pred_no_fuel_tm, args)
- (* Make the “$LEAST (even___P i)” term *)
- val least_pred_tm = mk_comb (least_tm, pred_no_fuel_tm)
- (* Make the “even___P i ($LEAST (even___P i))” *)
- val tm = subst [fuel_var0 |-> least_pred_tm] pred_tm
- (* Add the implication *)
- val tm = mk_imp (pred_tm, tm)
- (* Quantify *)
- val tm = list_mk_forall (args, tm)
- val tm = mk_forall (fuel_var0, tm)
- (* The proof tactic *)
- val prove_tac =
- rpt gen_tac >>
- disch_tac >>
- (* Use the "fundamental" property about $LEAST *)
- qspec_assume ‘^pred_no_fuel_tm’ whileTheory.LEAST_EXISTS_IMP >>
- (* Prove the premise *)
- pop_assum sg_premise_tac >- (exists_tac fuel_var0 >> fs []) >>
- rw []
- in
- save_goal_and_prove (tm, prove_tac)
- end
- in
- map prove_pred_n_imp_pred_least pred_defs
- end
-(*
-val (pred_def, mono_thm) = hd (zip pred_defs thl)
-val least_fuel_mono_thms = prove_least_fuel_mono pred_defs fuel_defs fuel_mono_thm
+ (* Debug *)
+ val tys = [(“:'a”, “:'b”)]
+ val scrut = “x : 'a + 'b”
+ val fi = 0
+ val is_input = true
-val least_fuel_mono_thm = hd least_fuel_mono_thms
-*)
+ val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”)]
+ val scrut = “x : 'a + 'b + 'c + 'd”
+ val fi = 1
+ val is_input = false
-(* Define the "raw" definitions:
+ val scrut = mk_var ("x", sumSyntax.list_mk_sum (flatten tys))
- {[
- even i = if (?n. even___P i n) then even___P ($LEAST (even___P i)) i else Diverge
- ]}
+ list_mk_case_select scrut tys fi is_input
*)
-fun define_raw_defs (def_tms : term list) (pred_defs : thm list) (fuel_defs : thm list) : thm list =
+fun list_mk_case_sum_select (scrut : term) (tys : (hol_type * hol_type) list)
+ (fi : int) (is_input : bool) : term =
let
- fun define_raw_def (def_tm, (pred_def, fuel_def)) : thm =
- let
- val app = lhs def_tm
- val pred_tm = (lhs o snd o strip_forall o concl) pred_def
- (* Make the “?n. even___P i n” term *)
- val exists_fuel_tm = mk_exists (fuel_var0, pred_tm)
- (* Make the “even___fuel ($LEAST (even___P i)) i” term *)
- val fuel_tm = (lhs o snd o strip_forall o concl) fuel_def
- val (pred_tm, args) = strip_comb pred_tm
- val args = rev (tl (rev args))
- val pred_tm = list_mk_comb (pred_tm, args)
- val least_pred_tm = mk_comb (least_tm, pred_tm)
- val fuel_tm = subst [fuel_var0 |-> least_pred_tm] fuel_tm
- (* Create the Diverge term *)
- val ret_ty = (hd o snd o dest_type) (type_of app)
- (* Create the “if then else” *)
- val body = TypeBase.mk_case (exists_fuel_tm, [(true_tm, fuel_tm), (false_tm, mk_diverge_tm ret_ty)])
- (* *)
- val raw_def_tm = mk_eq (app, body)
- in
- Define ‘^raw_def_tm’
- end
+ (* The index of the element in the enumeration that we will select *)
+ val i = 2 * fi + (if is_input then 0 else 1)
+ (* Flatten the types and numerotate them *)
+ fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
+ val tys = flatten tys
+ (* Get the return type *)
+ val ret_ty = List.nth (tys, i)
+ (* The continuation which will generate the content of the branches *)
+ fun gen_branch j var = if j = i then mk_return var else mk_fail_failure ret_ty
in
- map define_raw_def (zip def_tms (zip pred_defs fuel_defs))
+ (* Generate the ‘case ... of’ *)
+ list_mk_sum_case scrut tys gen_branch
end
-(*
-val raw_defs = define_raw_defs def_tms pred_defs fuel_defs
+(* Generate a ‘case ... of’ to select the input/output of the ith variant of
+ the param enumeration.
+
+ Ex.:
+ ====
+ There are two functions in the group, and we select the input of the function of index 1:
+ {[
+ case x of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure (* Input of function of index 0 *)
+ | INR (INL _) => Fail Failure (* Output of function of index 0 *)
+ | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *)
+ | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *)
+ ]}
*)
+fun mk_case_select_result_sum (scrut : term) (tys : (hol_type * hol_type) list)
+ (fi : int) (is_input : bool) : term =
+ (* We match over the result, then over the enumeration *)
+ mk_result_case scrut (fn x => list_mk_case_sum_select x tys fi is_input)
-(* Prove theorems of the shape:
+(* Generate a body for the fixed-point operator from a quoted group of mutually
+ recursive definitions.
- !n i. even___P i n ==> even___fuel n i = even i
+ See TODO for detailed explanations: from the quoted equations for ‘nth’
+ (or for [‘even’, ‘odd’]) we generate the body ‘nth_body’ (or ‘even_odd_body’,
+ respectively).
*)
-fun prove_pred_imp_fuel_eq_raw_defs
- (pred_defs : thm list)
- (fuel_def_tms : term list)
- (least_fuel_mono_thms : thm list)
- (least_pred_thms : thm list)
- (pred_n_imp_pred_least_thms : thm list)
- (raw_defs : thm list) :
- thm list =
+fun mk_body (fnames : string list) (in_out_tys : (hol_type * hol_type) list)
+ (def_tms : term list) : term =
let
- fun prove_thm (pred_def,
- (fuel_def_tm,
- (least_fuel_mono_thm,
- (least_pred_thm,
- (pred_n_imp_pred_least_thm, raw_def))))) : thm =
- let
- (* Generate: “even___P i n” *)
- val pred_tm = (lhs o snd o strip_forall o concl) pred_def
- val (pred_no_fuel_tm, args) = strip_comb pred_tm
- val args = rev (tl (rev args))
- (* Generate: “even___fuel n i” *)
- val fuel_tm = lhs fuel_def_tm
- (* Generate: “even i” *)
- val raw_def_tm = (lhs o snd o strip_forall o concl) raw_def
- (* Generate: “even___fuel n i = even i” *)
- val tm = mk_eq (fuel_tm, raw_def_tm)
- (* Add the implication *)
- val tm = mk_imp (pred_tm, tm)
- (* Quantify *)
- val tm = list_mk_forall (args, tm)
- val tm = mk_forall (fuel_var0, tm)
- (* Prove *)
- val prove_tac =
- rpt gen_tac >>
- strip_tac >>
- fs raw_defs >>
- (* Case on ‘?n. even___P i n’ *)
- CASE_TAC >> fs [] >>
- (* Use the monotonicity property *)
- irule least_fuel_mono_thm >>
- imp_res_tac pred_n_imp_pred_least_thm >> fs [] >>
- irule least_pred_thm >> fs []
- in
- save_goal_and_prove (tm, prove_tac)
- end
- in
- map prove_thm (zip pred_defs (zip fuel_def_tms (zip least_fuel_mono_thms
- (zip least_pred_thms (zip pred_n_imp_pred_least_thms raw_defs)))))
- end
+ val fnames_set = Redblackset.fromList String.compare fnames
+
+ (* Compute a map from function name to function index *)
+ val fnames_map = Redblackmap.fromList String.compare
+ (map (fn (x, y) => (y, x)) (enumerate fnames))
+
+ (* Compute the input/output type, that we dub the "parameter type" *)
+ fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
+ val param_type = sumSyntax.list_mk_sum (flatten in_out_tys)
+
+ (* Introduce a variable for the confinuation *)
+ val fcont = genvar (param_type --> mk_result param_type)
+
+ (* In the function equations, replace all the recursive calls with calls to the continuation.
+
+ When replacing a recursive call, we have to do two things:
+ - we need to inject the input parameters into the parameter type
+ Ex.:
+ - ‘nth tl i’ becomes ‘f (INL (tl, i))’ where ‘f’ is the continuation
+ - ‘even i’ becomes ‘f (INL i)’ where ‘f’ is the continuation
+ - we need to wrap the the call to the continuation into a ‘case ... of’
+ to extract its output (we need to make sure that the transformation
+ preserves the type of the expression!)
+ Ex.: ‘nth tl i’ becomes:
+ {[
+ case f (INL (tl, i)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR x => Return (INR x)
+ ]}
+ *)
+ (* For debugging *)
+ val replace_rec_calls_rec_call_tm = ref “T”
+ fun replace_rec_calls (fnames_set : string Redblackset.set) (tm : term) : term =
+ let
+ val _ = print_dbg ("replace_rec_calls: original expression:\n" ^
+ term_to_string tm ^ "\n\n")
+ val ntm =
+ case dest_term tm of
+ VAR (name, ty) =>
+ (* Check that this is not one of the functions in the group - remark:
+ we could handle that by introducing lambdas.
+ *)
+ if Redblackset.member (fnames_set, name)
+ then failwith ("mk_body: not well-formed definition: found " ^ name ^
+ " in an improper position")
+ else tm
+ | CONST _ => tm
+ | LAMB (x, tm) =>
+ let
+ (* The variable might shadow one of the functions *)
+ val fnames_set = Redblackset.delete (fnames_set, (fst o dest_var) x)
+ (* Update the term in the lambda *)
+ val tm = replace_rec_calls fnames_set tm
+ in
+ (* Reconstruct *)
+ mk_abs (x, tm)
+ end
+ | COMB (_, _) =>
+ let
+ (* Completely destruct the application, check if this is a recursive call *)
+ val (app, args) = strip_comb tm
+ val is_rec_call = Redblackset.member (fnames_set, (fst o dest_var) app)
+ handle HOL_ERR _ => false
+ (* Whatever the case, apply the transformation to all the inputs *)
+ val args = map (replace_rec_calls fnames_set) args
+ in
+ (* If this is not a recursive call: apply the transformation to all the
+ terms. Otherwise, replace. *)
+ if not is_rec_call then list_mk_comb (replace_rec_calls fnames_set app, args)
+ else
+ (* Rec call: replace *)
+ let
+ val _ = replace_rec_calls_rec_call_tm := tm
+ (* First, find the index of the function *)
+ val fname = (fst o dest_var) app
+ val fi = Redblackmap.find (fnames_map, fname)
+ (* Inject the input values into the param type *)
+ val input = pairSyntax.list_mk_pair args
+ val input = inject_in_param_sum in_out_tys fi true input
+ (* Create the recursive call *)
+ val call = mk_comb (fcont, input)
+ (* Wrap the call into a ‘case ... of’ to extract the output *)
+ val call = mk_case_select_result_sum call in_out_tys fi false
+ in
+ (* Return *)
+ call
+ end
+ end
+ val _ = print_dbg ("replace_rec_calls: new expression:\n" ^ term_to_string ntm ^ "\n\n")
+ in
+ ntm
+ end
+ handle HOL_ERR e =>
+ let
+ val _ = print_dbg ("replace_rec_calls: failed on:\n" ^ term_to_string tm ^ "\n\n")
+ in
+ raise (HOL_ERR e)
+ end
+ fun replace_rec_calls_in_eq (eq : term) : term =
+ let
+ val (l, r) = dest_eq eq
+ in
+ mk_eq (l, replace_rec_calls fnames_set r)
+ end
+ val def_tms_with_fcont = map replace_rec_calls_in_eq def_tms
-(*
-val pred_imp_fuel_eq_raw_defs =
- prove_pred_imp_fuel_eq_raw_defs
- pred_defs fuel_def_tms least_fuel_mono_thms least_pred_thms
- pred_n_imp_pred_least_thms raw_defs
- *)
+ (* Wrap all the function bodies to inject their result into the param type.
+ We collect the function inputs at the same time, because they will be
+ grouped into a tuple that we will have to deconstruct.
+ *)
+ fun inject_body_to_enums (i : int, def_eq : term) : term list * term =
+ let
+ val (l, body) = dest_eq def_eq
+ val (_, args) = strip_comb l
+ (* We have the deconstruct the result, then, in the ‘Return’ branch,
+ properly wrap the returned value *)
+ val body = mk_result_case body (fn x => mk_return (inject_in_param_sum in_out_tys i false x))
+ in
+ (args, body)
+ end
+ val def_tms_inject = map inject_body_to_enums (enumerate def_tms_with_fcont)
-(* Generate "expand" definitions of the following shape (we use them to
- hide the raw function bodies, to control the rewritings):
+ (* Currify the body inputs.
- {[
- even___expand even odd i : bool result =
- if i = 0 then Return T else odd (i - 1)
- ]}
+ For instance, if the body has inputs: ‘x’, ‘y’; we return the following:
+ {[
+ (‘z’, ‘case z of (x, y) => ... (* body *) ’)
+ ]}
+ where ‘z’ is fresh.
- {[
- odd___expand even odd i : bool result =
- if i = 0 then Return F else even (i - 1)
- ]}
+ We return: (curried input, body).
- *)
-fun gen_expand_defs (def_tms : term list) =
- let
- (* Generate the variables for “even”, “odd”, etc. *)
- val fun_vars = map (fst o strip_comb o lhs) def_tms
- val fun_tys = map type_of fun_vars
- (* Generate the expansion *)
- fun mk_def (def_tm : term) : thm =
- let
- val (exp_fun, args) = (strip_comb o lhs) def_tm
- val (exp_fun_str, exp_fun_ty) = dest_var exp_fun
- val exp_fun_str = exp_fun_str ^ expand_suffix
- val exp_fun_ty = list_mk_arrow fun_tys exp_fun_ty
- val exp_fun = mk_var (exp_fun_str, exp_fun_ty)
- val exp_fun = list_mk_comb (exp_fun, fun_vars)
- val exp_fun = list_mk_comb (exp_fun, args)
- val tm = mk_eq (exp_fun, rhs def_tm)
- in
- Define ‘^tm’
- end
+ (* Debug *)
+ val body = “(x:'a, y:'b, z:'c)”
+ val args = [“x:'a”, “y:'b”, “z:'c”]
+ currify_body_inputs (args, body)
+ *)
+ fun currify_body_inputs (args : term list, body : term) : term * term =
+ let
+ fun mk_curry (args : term list) (body : term) : term * term =
+ case args of
+ [] => failwith "no inputs"
+ | [x] => (x, body)
+ | x1 :: args =>
+ let
+ val (x2, body) = mk_curry args body
+ val scrut = genvar (pairSyntax.list_mk_prod (map type_of (x1 :: args)))
+ val pat = pairSyntax.mk_pair (x1, x2)
+ val br = body
+ in
+ (scrut, TypeBase.mk_case (scrut, [(pat, br)]))
+ end
+ in
+ mk_curry args body
+ end
+ val def_tms_currified = map currify_body_inputs def_tms_inject
+
+ (* Group all the functions into a single body, with an outer ‘case .. of’
+ which selects the appropriate body depending on the input *)
+ val param_ty = sumSyntax.list_mk_sum (flatten in_out_tys)
+ val input = genvar param_ty
+ fun mk_mut_rec_body_branch (i : int) (patvar : term) : term =
+ (* Case disjunction on whether the branch is for an input value (in
+ which case we call the proper body) or an output value (in which
+ case we return ‘Fail ...’ *)
+ if i mod 2 = 0 then
+ let
+ val fi = i div 2
+ val (x, def_tm) = List.nth (def_tms_currified, fi)
+ (* The variable in the pattern and the variable expected by the
+ body may not be the same: we introduce a let binding *)
+ val def_tm = mk_let (mk_abs (x, def_tm), patvar)
+ in
+ def_tm
+ end
+ else
+ (* Output value: fail *)
+ mk_fail_failure param_ty
+ val mut_rec_body = list_mk_sum_case input (flatten in_out_tys) mk_mut_rec_body_branch
+
+
+ (* Abstract away the parameters to produce the final body of the fixed point *)
+ val mut_rec_body = list_mk_abs ([fcont, input], mut_rec_body)
in
- map mk_def def_tms
+ mut_rec_body
end
-(*
-val def_tm = hd def_tms
+(*=============================================================================*
+ *
+ * Prove that the body satisfies the validity condition
+ *
+ * ============================================================================*)
-val expand_defs = gen_expand_defs def_tms
-*)
-
-(* Small utility:
-
- Return the list:
- {[
- (“even___P i n”, “even i = even___expand even odd i”),
- ...
- ]}
-
- *)
-fun mk_termination_diverge_tms
- (def_tms : term list)
- (pred_defs : thm list)
- (raw_defs : thm list)
- (expand_defs : thm list) :
- (term * term) list =
+(* Tactic to prove that a body is valid: perform one step. *)
+fun prove_body_is_valid_tac_step (asms, g) =
let
- (* Create the substitution for the "expand" functions:
+ (* The goal has the shape:
{[
- even -> even
- odd -> odd
- ...
- ]}
-
- where on the left we have *variables* and on the right we have
- the "raw" definitions.
+ (∀g h. ... g x = ... h x) ∨
+ ∃h y. is_valid_fp_body n h ∧ ∀g. ... g x = ... od
+ ]}
+ *)
+ (* Retrieve the scrutinee in the goal (‘x’).
+ There are two cases:
+ - either the function has the shape:
+ {[
+ (λ(y,z). ...) x
+ ]}
+ in which case we need to destruct ‘x’
+ - or we have a normal ‘case ... of’
*)
- fun mk_fun_subst (def_tm, raw_def) =
+ val body = (lhs o snd o strip_forall o fst o dest_disj) g
+ val scrut =
let
- val var = (fst o strip_comb o lhs) def_tm
- val f = (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def
+ val (app, x) = dest_comb body
+ val (app, _) = dest_comb app
+ val {Name=name, Thy=thy, Ty = _ } = dest_thy_const app
in
- (var |-> f)
+ if thy = "pair" andalso name = "UNCURRY" then x else failwith "not a curried argument"
end
- val fun_subst = map mk_fun_subst (zip def_tms raw_defs)
-
- fun mk_tm (pred_def, (raw_def, expand_def)) :
- term * term =
+ handle HOL_ERR _ => strip_all_cases_get_scrutinee body
+ (* Retrieve the first quantified continuations from the goal (‘g’) *)
+ val qc = (hd o fst o strip_forall o fst o dest_disj) g
+ (* Check if the scrutinee is a recursive call *)
+ val (scrut_app, _) = strip_comb scrut
+ val _ = print_dbg ("prove_body_is_valid_step: Scrutinee: " ^ term_to_string scrut ^ "\n")
+ (* For the recursive calls: *)
+ fun step_rec () =
let
- (* “even___P i n” *)
- val pred_tm = (lhs o snd o strip_forall o concl) pred_def
- (* “even i = even___expand even odd i” *)
- val expand_tm = (lhs o snd o strip_forall o concl) expand_def
- val expand_tm = subst fun_subst expand_tm
- val fun_tm = (lhs o snd o strip_forall o concl) raw_def
- val fun_eq_tm = mk_eq (fun_tm, expand_tm)
- in (pred_tm, fun_eq_tm) end
+ val _ = print_dbg ("prove_body_is_valid_step: rec call\n")
+ (* We need to instantiate the ‘h’ existantially quantified function *)
+ (* First, retrieve the body of the function: it is given by the ‘Return’ branch *)
+ val (_, _, branches) = TypeBase.dest_case body
+ (* Find the branch corresponding to the return *)
+ val ret_branch = List.find (fn (pat, _) =>
+ let
+ val {Name=name, Thy=thy, Ty = _ } = (dest_thy_const o fst o strip_comb) pat
+ in
+ thy = "primitives" andalso name = "Return"
+ end) branches
+ val var = (hd o snd o strip_comb o fst o valOf) ret_branch
+ val br = (snd o valOf) ret_branch
+ (* Abstract away the input variable introduced by the pattern and the continuation ‘g’ *)
+ val h = list_mk_abs ([qc, var], br)
+ val _ = print_dbg ("prove_body_is_valid_step: h: " ^ term_to_string h ^ "\n")
+ (* Retrieve the input parameter ‘x’ *)
+ val input = (snd o dest_comb) scrut
+ val _ = print_dbg ("prove_body_is_valid_step: y: " ^ term_to_string input ^ "\n")
+ in
+ ((* Choose the right possibility (this is a recursive call) *)
+ disj2_tac >>
+ (* Instantiate the quantifiers *)
+ qexists ‘^h’ >>
+ qexists ‘^input’ >>
+ (* Unfold the predicate once *)
+ pure_once_rewrite_tac [is_valid_fp_body_def] >>
+ (* We have two subgoals:
+ - we have to prove that ‘h’ is valid
+ - we have to finish the proof of validity for the current body
+ *)
+ conj_tac >> fs [case_result_switch_eq])
+ end
in
- map mk_tm (zip pred_defs (zip raw_defs expand_defs))
+ (* If recursive call: special treatment. Otherwise, we do a simple disjunction *)
+ (if term_eq scrut_app qc then step_rec ()
+ else (Cases_on ‘^scrut’ >> fs [case_result_switch_eq])) (asms, g)
end
-(*
-val term_div_tms =
- mk_termination_diverge_tms pred_defs raw_defs expand_defs
-*)
+(* Tactic to prove that a body is valid *)
+fun prove_body_is_valid_tac (body_def : thm option) : tactic =
+ let val body_def_thm = case body_def of SOME th => [th] | NONE => []
+ in
+ pure_once_rewrite_tac [is_valid_fp_body_def] >> gen_tac >>
+ (* Expand *)
+ fs body_def_thm >>
+ fs [bind_def, case_result_switch_eq] >>
+ (* Explore the body *)
+ rpt prove_body_is_valid_tac_step
+ end
-(* Prove the termination lemmas:
-
- {[
- !i.
- (?n. even___P i n) ==>
- even i = even___expand even odd i
- ]}
- *)
-fun prove_termination_thms
- (term_div_tms : (term * term) list)
- (fuel_defs : thm list)
- (pred_defs : thm list)
- (raw_defs : thm list)
- (expand_defs : thm list)
- (pred_n_imp_pred_least_thms : thm list)
- (pred_imp_fuel_eq_raw_defs : thm list)
- : thm list =
+(* Prove that a body satisfies the validity condition of the fixed point *)
+fun prove_body_is_valid (body : term) : thm =
let
- (* Create a map from functions in the recursive group to lemmas
- to apply *)
- fun mk_rec_fun_eq_pair (fuel_def, eq_th) =
- let
- val rfun = (get_fun_name_from_app o lhs o snd o strip_forall o concl) fuel_def
- in
- (rfun, eq_th)
- end
- val rec_fun_eq_map =
- Redblackmap.fromList const_name_compare (
- map mk_rec_fun_eq_pair
- (zip fuel_defs pred_imp_fuel_eq_raw_defs))
-
- (* Small tactic which rewrites the recursive calls *)
- fun rewrite_rec_call (asms, g) =
- let
- val scrut = (strip_all_cases_get_scrutinee o lhs) g
- val fun_id = get_fun_name_from_app scrut (* This can fail *)
- (* This can raise an exception - hence the handle at the end
- of the function *)
- val eq_th = Redblackmap.find (rec_fun_eq_map, fun_id)
- val eq_th = (UNDISCH_ALL o SPEC_ALL) eq_th
- (* Match the theorem *)
- val eq_th_tm = (lhs o concl) eq_th
- val (var_s, ty_s) = match_term eq_th_tm scrut
- val eq_th = INST var_s (INST_TYPE ty_s eq_th)
- val eq_th = thm_to_conj_implies eq_th
- (* Some tactics *)
- val premise_tac = fs pred_defs >> fs [is_diverge_def]
- in
- (* Apply the theorem, prove the premise, and rewrite *)
- (prove_premise_then premise_tac assume_tac eq_th >> fs []) (asms, g)
- end handle NotFound => all_tac (asms, g)
- | HOL_ERR _ => all_tac (asms, g) (* Getting the function name can also fail *)
-
- fun prove_one ((pred_tm, fun_eq_tm), pred_n_imp_pred_least_thm) :
- thm =
- let
- (* “?n. even___P i n” *)
- val pred_tm = mk_exists (fuel_var0, pred_tm)
- (* “even i = even___expand even odd i” *)
- val tm = fun_eq_tm
- (* Add the implication *)
- val tm = mk_imp (pred_tm, tm)
- (* Quantify *)
- val (_, args) = strip_comb (lhs fun_eq_tm)
- val tm = list_mk_forall (args, tm)
-
- (* Prove *)
- val prove_tac =
- rpt gen_tac >>
- disch_tac >>
-
- (* Expand the raw definition and get rid of the ‘?n ...’ *)
- pure_once_rewrite_tac raw_defs >>
- pure_asm_rewrite_tac [] >>
-
- (* Simplify *)
- fs [] >>
-
- (* Prove that: “even___P i $(LEAST ...)” *)
- imp_res_tac pred_n_imp_pred_least_thm >>
-
- (* We don't need the ‘even___P i n’ assumption anymore: we have a more
- precise one with the least upper bound *)
- last_x_assum ignore_tac >>
-
- (* Expand *)
- fs pred_defs >>
- fs [is_diverge_def] >>
- fs expand_defs >>
-
- (* We need to be a bit careful when expanding the definitions which use fuel:
- it can make the simplifier loop. *)
- rpt (pop_assum mp_tac) >>
- pure_once_rewrite_tac fuel_defs >>
- rpt disch_tac >>
+ (* Explore the body and count the number of occurrences of nested recursive
+ calls so that we can properly instantiate the ‘N’ argument of ‘is_valid_fp_body’.
+
+ We first retrieve the name of the continuation parameter.
+ Rem.: we generated fresh names so that, for instance, the continuation name
+ doesn't collide with other names. Because of this, we don't need to look for
+ collisions when exploring the body (and in the worst case, we would cound
+ an overapproximation of the number of recursive calls, which is perfectly
+ valid).
+ *)
+ val fcont = (hd o fst o strip_abs) body
+ val fcont_name = (fst o dest_var) fcont
+ fun max x y = if x > y then x else y
+ fun count_body_rec_calls (body : term) : int =
+ case dest_term body of
+ VAR (name, _) => if name = fcont_name then 1 else 0
+ | CONST _ => 0
+ | COMB (x, y) => max (count_body_rec_calls x) (count_body_rec_calls y)
+ | LAMB (_, x) => count_body_rec_calls x
+ val num_rec_calls = count_body_rec_calls body
+
+ (* Generate the term ‘SUC (SUC ... (SUC n))’ where ‘n’ is a fresh variable.
+
+ Remark: we first prove ‘is_valid_fp_body (SUC ... n) body’ then substitue
+ ‘n’ with ‘0’ to prevent the quantity from being rewritten to a bit
+ representation, which would prevent unfolding of the ‘is_valid_fp_body’.
+ *)
+ val nvar = genvar num_ty
+ (* Rem.: we stack num_rec_calls + 1 occurrences of ‘SUC’ (and the + 1 is important) *)
+ fun mk_n i = if i = 0 then mk_suc nvar else mk_suc (mk_n (i-1))
+ val n_tm = mk_n num_rec_calls
- (* Expand the binds *)
- fs [bind_def, case_result_same_eq] >>
+ (* Generate the lemma statement *)
+ val is_valid_tm = list_mk_icomb (is_valid_fp_body_tm, [n_tm, body])
+ val is_valid_thm = prove (is_valid_tm, prove_body_is_valid_tac NONE)
- (* Explore all the paths by doing case disjunctions *)
- rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq])
- in
- save_goal_and_prove (tm, prove_tac)
- end
+ (* Replace ‘nvar’ with ‘0’ *)
+ val is_valid_thm = INST [nvar |-> zero_num_tm] is_valid_thm
in
- map prove_one
- (zip term_div_tms pred_n_imp_pred_least_thms)
+ is_valid_thm
end
-(*
-val termination_thms =
- prove_termination_thms term_div_tms fuel_defs pred_defs
- raw_defs expand_defs pred_n_imp_pred_least_thms
- pred_imp_fuel_eq_raw_defs
-
-val ((pred_tm, fun_eq_tm), pred_n_imp_pred_least_thm) = hd (zip term_div_tms pred_n_imp_pred_least_thms)
-set_goal ([], tm)
-*)
-
-(* Prove the divergence lemmas:
-
- {[
- !i.
- (!n. ~even___P i n) ==>
- (!n. ~even___P i (SUC n)) ==>
- even i = even___expand even odd i
- ]}
+(*=============================================================================*
+ *
+ * Generate the definitions with the fixed-point operator
+ *
+ * ============================================================================*)
- Note that the shape of the theorem is very precise: this helps for the proof.
- Also, by correctly ordering the assumptions, we make sure that by rewriting
- we don't convert one of the two to “T”.
- *)
-fun prove_divergence_thms
- (term_div_tms : (term * term) list)
- (fuel_defs : thm list)
- (pred_defs : thm list)
- (raw_defs : thm list)
- (expand_defs : thm list)
- : thm list =
+(* Generate the raw definitions by using the grouped definition body and the
+ fixed point operator *)
+fun mk_raw_defs (in_out_tys : (hol_type * hol_type) list)
+ (def_tms : term list) (body_is_valid : thm) : thm list =
let
- (* Create a set containing the names of all the functions in the recursive group *)
- fun get_rec_fun_id (fuel_def : thm) =
- (get_fun_name_from_app o lhs o snd o strip_forall o concl) fuel_def
- val rec_fun_set =
- Redblackset.fromList const_name_compare (
- map get_rec_fun_id raw_defs)
-
- (* Small tactic which rewrites the recursive calls *)
- fun rewrite_rec_call (asms, g) =
- let
- val scrut = (strip_all_cases_get_scrutinee o lhs) g
- val fun_id = get_fun_name_from_app scrut (* This can fail *)
- in
- (* Check if the function is part of the group we are considering *)
- if Redblackset.member (rec_fun_set, fun_id) then
- let
- (* Create a subgoal “odd i = Diverge” *)
- val ret_ty = (hd o snd o dest_type o type_of) scrut
- val g = mk_eq (scrut, mk_diverge_tm ret_ty)
-
- (* Create a subgoal: “?n. odd___P i n”.
-
- It is a bit cumbersome because we have to lookup the proper
- predicate (from “odd” we need to lookup “odd___P”) and we
- may have to perform substitutions... We hack a bit by using
- a conversion to rewrite “odd i” to a term which contains
- the “?n. odd___P i n” we are looking for.
- *)
- val exists_g = (rhs o concl) (PURE_REWRITE_CONV raw_defs scrut)
- val (_, exists_g, _) = TypeBase.dest_case exists_g
- (* The tactic to prove the subgoal *)
- val prove_sg_tac =
- pure_rewrite_tac raw_defs >>
- Cases_on ‘^exists_g’ >> pure_asm_rewrite_tac [] >> fs [] >>
- (* There must only remain the positive case (i.e., “?n. ...”):
- we have a contradiction *)
- exfalso >>
- (* The end of the proof is done by opening the definitions *)
- pop_assum mp_tac >>
- fs pred_defs >> fs [is_diverge_def]
- in
- (SUBGOAL_THEN g assume_tac >- prove_sg_tac >> fs []) (asms, g)
- end
- else all_tac (asms, g) (* Nothing to do *)
- end handle HOL_ERR _ => all_tac (asms, g)
-
- fun prove_one (pred_tm, fun_eq_tm) :
- thm =
- let
- (* “!n. ~even___P i n” *)
- val neg_pred_tm = mk_neg pred_tm
- val pred_tm = mk_forall (fuel_var0, neg_pred_tm)
- val pred_suc_tm = subst [fuel_var0 |-> numSyntax.mk_suc fuel_var0] neg_pred_tm
- val pred_suc_tm = mk_forall (fuel_var0, pred_suc_tm)
-
- (* “even i = even___expand even odd i” *)
- val tm = fun_eq_tm
+ (* Retrieve the body *)
+ val body = (List.last o snd o strip_comb o concl) body_is_valid
- (* Add the implications *)
- val tm = list_mk_imp ([pred_tm, pred_suc_tm], tm)
+ (* Create the term ‘fix body’ *)
+ val fixed_body = mk_icomb (fix_tm, body)
- (* Quantify *)
- val (_, args) = strip_comb (lhs fun_eq_tm)
- val tm = list_mk_forall (args, tm)
+ (* For every function in the group, generate the equation that we will
+ use as definition. In particular:
+ - add the properly injected input ‘x’ to ‘fix body’ (ex.: for ‘nth ls i’
+ we add the input ‘INL (ls, i)’)
+ - wrap ‘fix body x’ into a case disjunction to extract the relevant output
- (* Prove *)
- val prove_tac =
- rpt gen_tac >>
-
- pure_rewrite_tac raw_defs >>
- rpt disch_tac >>
+ For instance, in the case of ‘nth ls i’:
+ {[
+ nth (ls : 't list_t) (i : u32) =
+ case fix nth_body (INL (ls, i)) of
+ | Fail e => Fail e
+ | Diverge => Diverge
+ | Return r =>
+ case r of
+ | INL _ => Fail Failure
+ | INR x => Return x
+ ]}
+ *)
+ fun mk_def_eq (i : int, def_tm : term) : term =
+ let
+ (* Retrieve the lhs of the original definition equation, and in
+ particular the inputs *)
+ val def_lhs = lhs def_tm
+ val args = (snd o strip_comb) def_lhs
- (* This allows to simplify the “?n. even___P i n” *)
- fs [] >>
- (* We don't need the last assumption anymore *)
- last_x_assum ignore_tac >>
+ (* Inject the inputs into the param type *)
+ val input = pairSyntax.list_mk_pair args
+ val input = inject_in_param_sum in_out_tys i true input
- (* Expand *)
- fs pred_defs >> fs [is_diverge_def] >>
- fs expand_defs >>
+ (* Compose*)
+ val def_rhs = mk_comb (fixed_body, input)
- (* We need to be a bit careful when expanding the definitions which use fuel:
- it can make the simplifier loop.
- *)
- pop_assum mp_tac >>
- pure_once_rewrite_tac fuel_defs >>
- rpt disch_tac >> fs [bind_def, case_result_same_eq] >>
+ (* Wrap in the case disjunction *)
+ val def_rhs = mk_case_select_result_sum def_rhs in_out_tys i false
- (* Evaluate all the paths *)
- rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq])
+ (* Create the equation *)
+ val def_eq_tm = mk_eq (def_lhs, def_rhs)
in
- save_goal_and_prove (tm, prove_tac)
+ def_eq_tm
end
+ val raw_def_tms = map mk_def_eq (enumerate def_tms)
+
+ (* Generate the definitions *)
+ val raw_defs = map (fn tm => Define ‘^tm’) raw_def_tms
in
- map prove_one term_div_tms
+ raw_defs
end
-(*
-val (pred_tm, fun_eq_tm) = hd term_div_tms
-set_goal ([], tm)
-
-val divergence_thms =
- prove_divergence_thms
- term_div_tms
- fuel_defs
- pred_defs
- raw_defs
- expand_defs
-*)
+(*=============================================================================*
+ *
+ * Prove that the definitions satisfy the target equations
+ *
+ * ============================================================================*)
-(* Prove the final lemmas:
+(* Tactic which makes progress in a proof by making a case disjunction (we use
+ this to explore all the paths in a function body). *)
+fun case_progress (asms, g) =
+ let
+ val scrut = (strip_all_cases_get_scrutinee o lhs) g
+ in Cases_on ‘^scrut’ (asms, g) end
- {[
- !i. even i = even___expand even odd i
- ]}
+(* Prove the final equation, that we will use as definition. *)
+fun prove_def_eq_tac
+ (current_raw_def : thm) (all_raw_defs : thm list) (is_valid : thm)
+ (body_def : thm option) : tactic =
+ let
+ val body_def_thm = case body_def of SOME th => [th] | NONE => []
+ in
+ rpt gen_tac >>
+ (* Expand the definition *)
+ pure_once_rewrite_tac [current_raw_def] >>
+ (* Use the fixed-point equality *)
+ pure_once_rewrite_left_tac [HO_MATCH_MP fix_fixed_eq is_valid] >>
+ (* Expand the body definition *)
+ pure_rewrite_tac body_def_thm >>
+ (* Expand all the definitions from the group *)
+ pure_rewrite_tac all_raw_defs >>
+ (* Explore all the paths - maybe we can be smarter, but this is fast and really easy *)
+ fs [bind_def] >>
+ rpt (case_progress >> fs [])
+ end
- Note that the shape of the theorem is very precise: this helps for the proof.
- Also, by correctly ordering the assumptions, we make sure that by rewriting
- we don't convert one of the two to “T”.
- *)
-fun prove_final_eqs
- (term_div_tms : (term * term) list)
- (termination_thms : thm list)
- (divergence_thms : thm list)
- (raw_defs : thm list)
- : thm list =
+(* Prove the final equations that we will give to the user as definitions *)
+fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm list) : thm list=
let
- fun prove_one ((pred_tm, fun_eq_tm), (termination_thm, divergence_thm)) : thm =
+ val defs_tgt_raw = zip def_tms raw_defs
+ (* Substitute the function variables with the constants introduced in the raw
+ definitions *)
+ fun compute_fsubst (def_tm, raw_def) : {redex: term, residue: term} =
+ let
+ val (fvar, _) = (strip_comb o lhs) def_tm
+ val fconst = (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def
+ in
+ (fvar |-> fconst)
+ end
+ val fsubst = map compute_fsubst defs_tgt_raw
+ val defs_tgt_raw = map (fn (x, y) => (subst fsubst x, y)) defs_tgt_raw
+
+ fun prove_def_eq (def_tm, raw_def) : thm =
let
- val (_, args) = strip_comb (lhs fun_eq_tm)
- val g = list_mk_forall (args, fun_eq_tm)
- (* We make a case disjunction of the subgoal: “exists n. even___P i n” *)
- val exists_g = (rhs o concl) (PURE_REWRITE_CONV raw_defs (lhs fun_eq_tm))
- val (_, exists_g, _) = TypeBase.dest_case exists_g
- val prove_tac =
- rpt gen_tac >>
- Cases_on ‘^exists_g’
- >-( (* Termination *)
- irule termination_thm >> pure_asm_rewrite_tac [])
- (* Divergence *)
- >> irule divergence_thm >> fs []
-
+ (* Quantify the parameters *)
+ val (_, params) = (strip_comb o lhs) def_tm
+ val def_eq_tm = list_mk_forall (params, def_tm)
+ (* Prove *)
+ val def_eq = prove (def_eq_tm, prove_def_eq_tac raw_def raw_defs body_is_valid NONE)
in
- save_goal_and_prove (g, prove_tac)
- end
+ def_eq
+ end
+ val def_eqs = map prove_def_eq defs_tgt_raw
in
- map prove_one (zip term_div_tms (zip termination_thms divergence_thms))
+ def_eqs
end
-(*
-val termination_thm = hd termination_thms
-val divergence_thm = hd divergence_thms
-set_goal ([], g)
-*)
+(*=============================================================================*
+ *
+ * The final DefineDiv function
+ *
+ * ============================================================================*)
-(* The final function: define potentially diverging functions in an error monad *)
fun DefineDiv (def_qt : term quotation) =
let
- (* Parse the definitions.
-
- Example:
- {[
- (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\
- (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1))
- ]}
- *)
+ (* Parse the definitions *)
val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
- (* Generate definitions which use some fuel
-
- Example:
- {[
- even___fuel n i =
- case fuel of
- 0 => Diverge
- | SUC fuel =>
- if i = 0 then Return T else odd_fuel (i - 1))
- ]}
- *)
- val fuel_defs = mk_fuel_defs def_tms
-
- (* Generate the predicate definitions.
-
- {[ even___P n i = = ~is_diverge (even___fuel n i) ]}
- *)
- val fuel_def_tms = map (snd o strip_forall o concl) fuel_defs
- val pred_defs = map mk_fuel_predicate_defs (zip def_tms fuel_def_tms)
-
- (* Prove the monotonicity property for the fuel, all at once
-
- *)
- val fuel_mono_thm = prove_fuel_mono pred_defs fuel_defs
-
- (* Prove the individual fuel functions - TODO: update
-
- {[
- !n i. $LEAST (even___P i) <= n ==> even___fuel n i = even___fuel ($LEAST (even___P i)) i
- ]}
- *)
- val least_fuel_mono_thms = prove_least_fuel_mono pred_defs fuel_mono_thm
-
- (*
- {[
- !n i. even___P i n ==> $LEAST (even___P i) <= n
- ]}
- *)
- val least_pred_thms = prove_least_pred_thms pred_defs
-
- (*
- {[
- !n i. even___P i n ==> even___P i ($LEAST (even___P i))
- ]}
- *)
- val pred_n_imp_pred_least_thms = prove_pred_n_imp_pred_least_thms pred_defs
-
- (*
- "Raw" definitions:
-
- {[
- even i = if (?n. even___P i n) then even___P ($LEAST (even___P i)) i else Diverge
- ]}
- *)
- val raw_defs = define_raw_defs def_tms pred_defs fuel_defs
-
- (*
- !n i. even___P i n ==> even___fuel n i = even i
- *)
- val pred_imp_fuel_eq_raw_defs =
- prove_pred_imp_fuel_eq_raw_defs
- pred_defs fuel_def_tms least_fuel_mono_thms
- least_pred_thms pred_n_imp_pred_least_thms raw_defs
-
- (* "Expand" definitions *)
- val expand_defs = gen_expand_defs def_tms
-
- (* Small utility *)
- val term_div_tms = mk_termination_diverge_tms def_tms pred_defs raw_defs expand_defs
-
- (* Termination theorems *)
- val termination_thms =
- prove_termination_thms term_div_tms fuel_defs pred_defs
- raw_defs expand_defs pred_n_imp_pred_least_thms pred_imp_fuel_eq_raw_defs
+ (* Compute the names and the input/output types of the functions *)
+ fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) =
+ let
+ val app = lhs tm
+ val name = (fst o dest_var o fst o strip_comb) app
+ val out_ty = dest_result (type_of app)
+ val input_tys = pairSyntax.list_mk_prod (map type_of ((snd o strip_comb) app))
+ in
+ (name, (input_tys, out_ty))
+ end
+ val (fnames, in_out_tys) = unzip (map compute_names_in_out_tys def_tms)
- (* Divergence theorems *)
- val divergence_thms =
- prove_divergence_thms term_div_tms fuel_defs pred_defs raw_defs expand_defs
+ (* Generate the body to give to the fixed-point operator *)
+ val body = mk_body fnames in_out_tys def_tms
- (* Final theorems:
+ (* Prove that the body satisfies the validity property required by the fixed point *)
+ val body_is_valid = prove_body_is_valid body
+
+ (* Generate the definitions for the various functions by using the fixed point
+ and the body *)
+ val raw_defs = mk_raw_defs in_out_tys def_tms body_is_valid
- {[
- ∀i. even i = even___E even odd i,
- ⊢ ∀i. odd i = odd___E even odd i
- ]}
- *)
- val final_eqs = prove_final_eqs term_div_tms termination_thms divergence_thms raw_defs
- val final_eqs = map (PURE_REWRITE_RULE expand_defs) final_eqs
+ (* Prove the final equations *)
+ val def_eqs = prove_def_eqs body_is_valid def_tms raw_defs
in
- (* We return the final equations, which act as rewriting theorems *)
- final_eqs
+ def_eqs
end
end