summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-01-31 22:31:43 +0100
committerSon HO2023-06-04 21:54:38 +0200
commit232219b6465a84ebd31b6821c0950c7c380d1824 (patch)
tree05c284955c5dd2cfcf6414d68e8f3350ef938b6b
parentbbe4a8b234d183e36c157dbc6b9900214e405a52 (diff)
Finish a first working version of divDefLib.sml
-rw-r--r--backends/hol4/divDefLib.sml376
1 files changed, 347 insertions, 29 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml
index bfd36af1..ef18d14f 100644
--- a/backends/hol4/divDefLib.sml
+++ b/backends/hol4/divDefLib.sml
@@ -15,30 +15,21 @@ raw_def_thm -> raw_def
fuel_defs_thm -> fuel_defs (and split the theorem)
*)
-(* TotalDefn.Define *)
-
-(*
-Datatype:
- list_t =
- ListCons 't list_t
- | ListNil
-End
-*)
-
(* TODO: move *)
fun list_mk_arrow (tys : hol_type list) (ret_ty : hol_type) : hol_type =
foldr (fn (ty, aty) => ty --> aty) ret_ty tys
-(*“test (x : bool) = (x <> F)” *)
-
+(*
val def_qt = ‘
(even (i : int) : bool result =
if i = 0 then Return T else odd (i - 1)) /\
(odd (i : int) : bool result =
if i = 0 then Return F else even (i - 1))
+
val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
val def_tm = hd def_tms
+*)
(* Small utilities *)
val current_goal : term option ref = ref NONE
@@ -87,10 +78,13 @@ val fuel_var1 = mk_var ("$m", “:num”) (* TODO: name collisions *)
val fuel_vars_le = “^fuel_var0 <= ^fuel_var1”
val fuel_predicate_suffix = "___P" (* TODO: name collisions *)
+val expand_suffix = "___E" (* TODO: name collisions *)
val bool_ty = “:bool”
val alpha_tyvar : hol_type = “:'a”
+val beta_tyvar : hol_type = “:'b”
+
val is_diverge_def = Define ‘
is_diverge (r: 'a result) : bool = case r of Diverge => T | _ => F’
val is_diverge_tm = “is_diverge: 'a result -> bool”
@@ -101,6 +95,8 @@ val le_tm = (fst o strip_comb) “x:num <= y:num”
val true_tm = “T”
val false_tm = “F”
+val measure_tm = “measure: ('a -> num) -> 'a -> 'a -> bool”
+
fun mk_diverge_tm (ty : hol_type) : term =
let
val diverge_ty = mk_thy_type {Thy="primitives", Tyop="result", Args = [ty] }
@@ -109,6 +105,143 @@ fun mk_diverge_tm (ty : hol_type) : term =
diverge_tm
end
+(*
+(* TODO: move *)
+fun strip_pair_type (ty : hol_type) : hol_type list =
+ let
+ val {Args=args, Thy=thy, Tyop=tyop} = dest_thy_type ty
+ in
+ if thy = "pair" andalso tyop = "prod" then
+ case args of
+ [x, y] => x :: strip_pair_type y
+ | _ => failwith "Unexpected"
+ else [ty]
+ end
+ handle HOL_ERR _ => [ty]
+
+fun list_mk_pair_type (tys : hol_type list) : hol_type =
+ case tys of
+ [] => failwith "Unexpected"
+ | [x] => x
+ | x :: tys =>
+ mk_thy_type {Args = [x, list_mk_pair_type tys], Thy="pair", Tyop="prod"}
+
+fun strip_sum_type (ty : hol_type) : hol_type list =
+ let
+ val {Args=args, Thy=thy, Tyop=tyop} = dest_thy_type ty
+ in
+ if thy = "sum" andalso tyop = "sum" then
+ case args of
+ [x, y] => x :: strip_sum_type y
+ | _ => failwith "Unexpected"
+ else [ty]
+ end
+ handle HOL_ERR _ => [ty]
+
+fun list_mk_sum_type (tys : hol_type list) : hol_type =
+ case tys of
+ [] => failwith "Unexpected"
+ | [x] => x
+ | x :: tys =>
+ mk_thy_type {Args = [x, list_mk_sum_type tys], Thy="sum", Tyop="sum"}
+*)
+
+
+(*
+val ty = “: ('a # 'b) # num # int”
+strip_pair_type ty
+list_mk_pair_type (strip_pair_type ty)
+
+val ty = “: ('a + 'b) + num + int”
+strip_sum_type ty
+list_mk_sum_type (strip_sum_type ty)
+
+val ty = “: (num # 'a # bool) + (num # int # bool) + (num # 'a)”
+*)
+
+(* Small utility: we sometimes need to generate a termination measure for
+ the fuel definitions.
+
+ We derive a measure for a type which is simply the sum of the tuples
+ of the input types of the functions.
+
+ For instance, for even and odd we have:
+ {[
+ even___fuel : num -> int -> bool result
+ odd___fuel : num -> int -> bool result
+ ]}
+
+ So the type would be:
+ {[
+ (num # int) + (num # int)
+ ]}
+
+ Note that generally speaking we expect a type of the shape (the “:num”
+ on the left is for the fuel):
+ {[
+ (num # ...) + (num # ...) + ... + (num # ...)
+ ]}
+
+ The decreasing measure is simply given by a function which matches over
+ its argument to return the fuel, whatever the case.
+ *)
+fun mk_termination_measure_from_ty (ty : hol_type) : term =
+ let
+ val dtys = map pairSyntax.strip_prod (sumSyntax.strip_sum ty)
+ (* For every tuple, create a match to extract the num *)
+ fun mk_case_of_tuple (tys : hol_type list) : (term * term) =
+ case tys of
+ [] => failwith "mk_termination_measure_from_ty: empty list of types"
+ | [num_ty] =>
+ (* No need for a case *)
+ let val var = genvar num_ty in (var, var) end
+ | num_ty :: rem_tys =>
+ let
+ val scrut_var = genvar (pairSyntax.list_mk_prod tys)
+ val var = genvar num_ty
+ val rem_var = genvar (pairSyntax.list_mk_prod rem_tys)
+ val pats = [(pairSyntax.mk_pair (var, rem_var), var)]
+ val case_tm = TypeBase.mk_case (scrut_var, pats)
+ in
+ (scrut_var, case_tm)
+ end
+ val tuple_cases = map mk_case_of_tuple dtys
+
+ (* For every sum, create a match to extract one of the tuples *)
+ fun mk_sum_case ((tuple_var, tuple_case), (nvar, case_end)) =
+ let
+ val left_pat = sumSyntax.mk_inl (tuple_var, type_of nvar)
+ val right_pat = sumSyntax.mk_inr (nvar, type_of tuple_var)
+ val scrut = genvar (sumSyntax.mk_sum (type_of tuple_var, type_of nvar))
+ val pats = [(left_pat, tuple_case), (right_pat, case_end)]
+ val case_tm = TypeBase.mk_case (scrut, pats)
+ in
+ (scrut, case_tm)
+ end
+ val tuple_cases = rev tuple_cases
+ val (nvar, case_end) = hd tuple_cases
+ val tuple_cases = tl tuple_cases
+ val (scrut, case_tm) = foldl mk_sum_case (nvar, case_end) tuple_cases
+
+ (* Create the function *)
+ val abs_tm = mk_abs (scrut, case_tm)
+
+ (* Add the “measure term” *)
+ val tm = inst [alpha_tyvar |-> type_of scrut] measure_tm
+ val tm = mk_comb (tm, abs_tm)
+ in
+ tm
+ end
+
+(*
+val ty = “: (num # 'a) + (num # 'b) + (num # 'c)”
+
+val tys = hd dtys
+val num_ty::rem_tys = tys
+
+val (tuple_var, tuple_case) = hd tuple_cases
+*)
+
fun mk_fuel_defs (def_tms : term list) : thm =
let
(* Retrieve the identifiers.
@@ -116,17 +249,16 @@ fun mk_fuel_defs (def_tms : term list) : thm =
Ex.: def_tm = “even (n : int) : bool result = if i = 0 then Return T else odd (i - 1))”
We want to retrive: id = “even”
*)
- val app = lhs def_tm
val ids = map (fst o strip_comb o lhs) def_tms
(* In the definitions, replace the identifiers by new identifiers which use
fuel.
Ex.:
def_fuel_tm = “
- even_fuel (fuel : nat) (n : int) : result bool =
+ even___fuel (fuel : nat) (n : int) : result bool =
case fuel of 0 => Diverge
- | SUC fuel =>
- if i = 0 then Return T else odd_fuel (i - 1))”
+ | SUC fuel' =>
+ if i = 0 then Return T else odd_fuel fuel' (i - 1))”
*)
fun mk_fuel_id (id : term) : term =
let
@@ -138,16 +270,25 @@ fun mk_fuel_defs (def_tms : term list) : thm =
in fuel_id end
val fuel_ids = map mk_fuel_id ids
- val fuel_ids_with_fuel = map (fn id => mk_comb (id, fuel_var)) fuel_ids
+ val fuel_ids_with_fuel0 = map (fn id => mk_comb (id, fuel_var0)) fuel_ids
+ val fuel_ids_with_fuel1 = map (fn id => mk_comb (id, fuel_var1)) fuel_ids
(* Recurse through the terms and replace the calls *)
- val rwr_thms = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel)
- val fuel_tms = map (rhs o concl o (PURE_REWRITE_CONV rwr_thms)) def_tms
+ val rwr_thms0 = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel0)
+ val rwr_thms1 = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel1)
+
+ fun mk_fuel_tm (def_tm : term) : term =
+ let
+ val (tm0, tm1) = dest_eq def_tm
+ val tm0 = (rhs o concl o (PURE_REWRITE_CONV rwr_thms0)) tm0
+ val tm1 = (rhs o concl o (PURE_REWRITE_CONV rwr_thms1)) tm1
+ in mk_eq (tm0, tm1) end
+ val fuel_tms = map mk_fuel_tm def_tms
(* Add the case over the fuel *)
fun add_fuel_case (tm : term) : term =
let
- val (app, body) = dest_eq tm
+ val (f, body) = dest_eq tm
(* Create the “Diverge” term with the proper type *)
val body_ty = type_of body
val return_ty =
@@ -155,20 +296,47 @@ fun mk_fuel_defs (def_tms : term list) : thm =
| _ => failwith "unexpected"
val diverge_tm = mk_diverge_tm return_ty
(* Create the “SUC fuel” term *)
- val suc_tm = mk_comb (num_suc_tm, fuel_var)
- val fuel_tm = TypeBase.mk_case (fuel_var, [(num_zero_tm, diverge_tm), (suc_tm, body)])
- in mk_eq (app, fuel_tm) end
+ val suc_tm = mk_comb (num_suc_tm, fuel_var1)
+ val fuel_tm =
+ TypeBase.mk_case (fuel_var0, [(num_zero_tm, diverge_tm), (suc_tm, body)])
+ in mk_eq (f, fuel_tm) end
val fuel_tms = map add_fuel_case fuel_tms
(* Define the auxiliary definitions which use fuel *)
val fuel_defs_conj = list_mk_conj fuel_tms
+ (* The definition name *)
+ val def_name = (fst o dest_var o hd) fuel_ids
+ (* The tactic to prove the termination *)
+ val rty = ref “:bool”
+ fun prove_termination_tac (asms, g) =
+ let
+ val _ = print_term g
+ val r_tm = (fst o dest_exists) g
+ val _ = rty := type_of r_tm
+ val ty = (hd o snd o dest_type) (!rty)
+ val m_tm = mk_termination_measure_from_ty ty
+ in
+ WF_REL_TAC ‘^m_tm’ (asms, g)
+ end
(* Define the fuel definitions *)
- val fuel_defs_thm = Define ‘^fuel_defs_conj’
+ (*
+ val temp_def = Hol_defn def_name ‘^fuel_defs_conj’
+ Defn.tgoal temp_def
+ *)
+ val fuel_defs_thm = tDefine def_name ‘^fuel_defs_conj’ prove_termination_tac
in
fuel_defs_thm
end
(*
+val (asms, g) = top_goal ()
+
+val rty = ref “:num”
+val ty = !rty
+(hd o snd o dest_type) ty
+*)
+
+(*
val (fuel_tms, fuel_defs_thm) = mk_fuel_defs def_tms
val fuel_def_tms = map (snd o strip_forall) ((strip_conj o concl) fuel_defs_thm)
val (def_tm, fuel_def_tm) = hd (zip def_tms fuel_def_tms)
@@ -215,7 +383,7 @@ fun case_progress (asms, g) =
val scrut = (strip_all_cases_get_scrutinee o lhs) g
in Cases_on ‘^scrut’ (asms, g) end
-(* Tactic to prove the fuel monotonicity theorem *)
+(* Tactic to prove the fuel monotonicity theorem - TODO: move below *)
fun prove_fuel_mono_tac (pred_def_thms : thm list) (fuel_defs_thm : thm) =
Induct_on ‘^fuel_var0’ >-(
(* The ___P predicates are false: n is 0 *)
@@ -269,8 +437,6 @@ fun prove_fuel_mono (pred_def_thms : thm list) (fuel_defs_thm : thm) : thm =
val fuel_eq_tms = map mk_fuel_eq_tm pred_fuel_tms
(* Create the conjunction *)
val fuel_eq_tms = list_mk_conj fuel_eq_tms
-(* (* Add the “n <= m ==> ...” implication *)
- val fuel_eq_tms = mk_imp (fuel_vars_le, fuel_eq_tms) *)
(* Qantify over the fuels *)
val fuel_eq_tms = list_mk_forall ([fuel_var0, fuel_var1], fuel_eq_tms)
in
@@ -587,6 +753,7 @@ val expand_defs = gen_expand_defs def_tms
*)
fun mk_termination_diverge_tms
+ (def_tms : term list)
(pred_def_thms : thm list)
(raw_def_thms : thm list)
(expand_defs : thm list) :
@@ -694,7 +861,7 @@ fun prove_termination_thms
(* Add the implication *)
val tm = mk_imp (pred_tm, tm)
(* Quantify *)
- val (_, args) = strip_comb fun_tm
+ val (_, args) = strip_comb (lhs fun_eq_tm)
val tm = list_mk_forall (args, tm)
(* Prove *)
@@ -838,7 +1005,7 @@ fun prove_divergence_thms
val tm = list_mk_imp ([pred_tm, pred_suc_tm], tm)
(* Quantify *)
- val (_, args) = strip_comb fun_tm
+ val (_, args) = strip_comb (lhs fun_eq_tm)
val tm = list_mk_forall (args, tm)
(* Prove *)
@@ -888,3 +1055,154 @@ set_goal ([], tm)
val (asms, g) = top_goal ()
*)
+
+(* Prove the final lemmas:
+
+ {[
+ !i. even i = even___expand even odd i
+ ]}
+
+ Note that the shape of the theorem is very precise: this helps for the proof.
+ Also, by correctly ordering the assumptions, we make sure that by rewriting
+ we don't convert one of the two to “T”.
+ *)
+fun prove_final_eqs
+ (term_div_tms : (term * term) list)
+ (termination_thms : thm list)
+ (divergence_thms : thm list)
+ (raw_def_thms : thm list)
+ : thm list =
+ let
+ fun prove_one ((pred_tm, fun_eq_tm), (termination_thm, divergence_thm)) : thm =
+ let
+ val (_, args) = strip_comb (lhs fun_eq_tm)
+ val g = list_mk_forall (args, fun_eq_tm)
+ (* We make a case disjunction of the subgoal: “exists n. even___P i n” *)
+ val exists_g = (rhs o concl) (PURE_REWRITE_CONV raw_def_thms (lhs fun_eq_tm))
+ val (_, exists_g, _) = TypeBase.dest_case exists_g
+ val prove_tac =
+ rpt gen_tac >>
+ Cases_on ‘^exists_g’
+ >-( (* Termination *)
+ irule termination_thm >> pure_asm_rewrite_tac [])
+ (* Divergence *)
+ >> irule divergence_thm >> fs []
+
+ in
+ save_goal_and_prove (g, prove_tac)
+ end
+ in
+ map prove_one (zip term_div_tms (zip termination_thms divergence_thms))
+ end
+
+(*
+val termination_thm = hd termination_thms
+val divergence_thm = hd divergence_thms
+set_goal ([], g)
+*)
+
+(* The final function: define potentially diverging functions in an error monad *)
+fun DefineDiv (def_qt : term quotation) =
+ let
+ (* Parse the definitions.
+
+ Example:
+ {[
+ (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\
+ (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1))
+ ]}
+ *)
+ val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
+
+ (* Generate definitions which use some fuel
+
+ Example:
+ {[
+ even___fuel n i =
+ case fuel of
+ 0 => Diverge
+ | SUC fuel =>
+ if i = 0 then Return T else odd_fuel (i - 1))
+ ]}
+ *)
+ (* TODO: list of theorems *)
+ val fuel_defs_thm = mk_fuel_defs def_tms
+
+ (* Generate the predicate definitions.
+
+ {[ even___P n i = = ~is_diverge (even___fuel n i) ]}
+ *)
+ val fuel_def_tms = map (snd o strip_forall) ((strip_conj o concl) fuel_defs_thm)
+ val pred_def_thms = map mk_fuel_predicate_defs (zip def_tms fuel_def_tms)
+
+ (* Prove the monotonicity property for the fuel, all at once
+
+ *)
+ val fuel_mono_thm = prove_fuel_mono pred_def_thms fuel_defs_thm
+
+ (* Prove the individual fuel functions - TODO: update
+
+ {[
+ !n i. $LEAST (even___P i) <= n ==> even___fuel n i = even___fuel ($LEAST (even___P i)) i
+ ]}
+ *)
+ val least_fuel_mono_thms = prove_least_fuel_mono pred_def_thms fuel_mono_thm
+
+ (*
+ {[
+ !n i. even___P i n ==> $LEAST (even___P i) <= n
+ ]}
+ *)
+ val least_pred_thms = prove_least_pred_thms pred_def_thms
+
+ (*
+ {[
+ !n i. even___P i n ==> even___P i ($LEAST (even___P i))
+ ]}
+ *)
+ val pred_n_imp_pred_least_thms = prove_pred_n_imp_pred_least_thms pred_def_thms
+
+ (*
+ "Raw" definitions:
+
+ {[
+ even i = if (?n. even___P i n) then even___P ($LEAST (even___P i)) i else Diverge
+ ]}
+ *)
+ val raw_def_thms = define_raw_defs def_tms pred_def_thms fuel_defs_thm
+
+ (*
+ !n i. even___P i n ==> even___fuel n i = even i
+ *)
+ val pred_imp_fuel_eq_raw_def_thms =
+ prove_pred_imp_fuel_eq_raw_def_thms
+ pred_def_thms fuel_def_tms least_fuel_mono_thms
+ least_pred_thms pred_n_imp_pred_least_thms raw_def_thms
+
+ (* "Expand" definitions *)
+ val expand_defs = gen_expand_defs def_tms
+
+ (* Small utility *)
+ val term_div_tms = mk_termination_diverge_tms def_tms pred_def_thms raw_def_thms expand_defs
+
+ (* Termination theorems *)
+ val termination_thms =
+ prove_termination_thms term_div_tms fuel_defs_thm pred_def_thms
+ raw_def_thms expand_defs pred_n_imp_pred_least_thms pred_imp_fuel_eq_raw_def_thms
+
+ (* Divergence theorems *)
+ val divergence_thms =
+ prove_divergence_thms term_div_tms fuel_defs_thm pred_def_thms raw_def_thms expand_defs
+
+ (* Final theorems:
+
+ {[
+ ∀i. even i = even___E even odd i,
+ ⊢ ∀i. odd i = odd___E even odd i
+ ]}
+ *)
+ val final_eqs = prove_final_eqs term_div_tms termination_thms divergence_thms raw_def_thms
+ in
+ (* We return the final equations, which act as rewriting theorems *)
+ final_eqs
+ end