diff options
author | Son Ho | 2023-01-31 22:31:43 +0100 |
---|---|---|
committer | Son HO | 2023-06-04 21:54:38 +0200 |
commit | 232219b6465a84ebd31b6821c0950c7c380d1824 (patch) | |
tree | 05c284955c5dd2cfcf6414d68e8f3350ef938b6b | |
parent | bbe4a8b234d183e36c157dbc6b9900214e405a52 (diff) |
Finish a first working version of divDefLib.sml
-rw-r--r-- | backends/hol4/divDefLib.sml | 376 |
1 files changed, 347 insertions, 29 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml index bfd36af1..ef18d14f 100644 --- a/backends/hol4/divDefLib.sml +++ b/backends/hol4/divDefLib.sml @@ -15,30 +15,21 @@ raw_def_thm -> raw_def fuel_defs_thm -> fuel_defs (and split the theorem) *) -(* TotalDefn.Define *) - -(* -Datatype: - list_t = - ListCons 't list_t - | ListNil -End -*) - (* TODO: move *) fun list_mk_arrow (tys : hol_type list) (ret_ty : hol_type) : hol_type = foldr (fn (ty, aty) => ty --> aty) ret_ty tys -(*“test (x : bool) = (x <> F)” *) - +(* val def_qt = ‘ (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\ (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1)) ’ + val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt) val def_tm = hd def_tms +*) (* Small utilities *) val current_goal : term option ref = ref NONE @@ -87,10 +78,13 @@ val fuel_var1 = mk_var ("$m", “:num”) (* TODO: name collisions *) val fuel_vars_le = “^fuel_var0 <= ^fuel_var1” val fuel_predicate_suffix = "___P" (* TODO: name collisions *) +val expand_suffix = "___E" (* TODO: name collisions *) val bool_ty = “:bool” val alpha_tyvar : hol_type = “:'a” +val beta_tyvar : hol_type = “:'b” + val is_diverge_def = Define ‘ is_diverge (r: 'a result) : bool = case r of Diverge => T | _ => F’ val is_diverge_tm = “is_diverge: 'a result -> bool” @@ -101,6 +95,8 @@ val le_tm = (fst o strip_comb) “x:num <= y:num” val true_tm = “T” val false_tm = “F” +val measure_tm = “measure: ('a -> num) -> 'a -> 'a -> bool” + fun mk_diverge_tm (ty : hol_type) : term = let val diverge_ty = mk_thy_type {Thy="primitives", Tyop="result", Args = [ty] } @@ -109,6 +105,143 @@ fun mk_diverge_tm (ty : hol_type) : term = diverge_tm end +(* +(* TODO: move *) +fun strip_pair_type (ty : hol_type) : hol_type list = + let + val {Args=args, Thy=thy, Tyop=tyop} = dest_thy_type ty + in + if thy = "pair" andalso tyop = "prod" then + case args of + [x, y] => x :: strip_pair_type y + | _ => failwith "Unexpected" + else [ty] + end + handle HOL_ERR _ => [ty] + +fun list_mk_pair_type (tys : hol_type list) : hol_type = + case tys of + [] => failwith "Unexpected" + | [x] => x + | x :: tys => + mk_thy_type {Args = [x, list_mk_pair_type tys], Thy="pair", Tyop="prod"} + +fun strip_sum_type (ty : hol_type) : hol_type list = + let + val {Args=args, Thy=thy, Tyop=tyop} = dest_thy_type ty + in + if thy = "sum" andalso tyop = "sum" then + case args of + [x, y] => x :: strip_sum_type y + | _ => failwith "Unexpected" + else [ty] + end + handle HOL_ERR _ => [ty] + +fun list_mk_sum_type (tys : hol_type list) : hol_type = + case tys of + [] => failwith "Unexpected" + | [x] => x + | x :: tys => + mk_thy_type {Args = [x, list_mk_sum_type tys], Thy="sum", Tyop="sum"} +*) + + +(* +val ty = “: ('a # 'b) # num # int” +strip_pair_type ty +list_mk_pair_type (strip_pair_type ty) + +val ty = “: ('a + 'b) + num + int” +strip_sum_type ty +list_mk_sum_type (strip_sum_type ty) + +val ty = “: (num # 'a # bool) + (num # int # bool) + (num # 'a)” +*) + +(* Small utility: we sometimes need to generate a termination measure for + the fuel definitions. + + We derive a measure for a type which is simply the sum of the tuples + of the input types of the functions. + + For instance, for even and odd we have: + {[ + even___fuel : num -> int -> bool result + odd___fuel : num -> int -> bool result + ]} + + So the type would be: + {[ + (num # int) + (num # int) + ]} + + Note that generally speaking we expect a type of the shape (the “:num” + on the left is for the fuel): + {[ + (num # ...) + (num # ...) + ... + (num # ...) + ]} + + The decreasing measure is simply given by a function which matches over + its argument to return the fuel, whatever the case. + *) +fun mk_termination_measure_from_ty (ty : hol_type) : term = + let + val dtys = map pairSyntax.strip_prod (sumSyntax.strip_sum ty) + (* For every tuple, create a match to extract the num *) + fun mk_case_of_tuple (tys : hol_type list) : (term * term) = + case tys of + [] => failwith "mk_termination_measure_from_ty: empty list of types" + | [num_ty] => + (* No need for a case *) + let val var = genvar num_ty in (var, var) end + | num_ty :: rem_tys => + let + val scrut_var = genvar (pairSyntax.list_mk_prod tys) + val var = genvar num_ty + val rem_var = genvar (pairSyntax.list_mk_prod rem_tys) + val pats = [(pairSyntax.mk_pair (var, rem_var), var)] + val case_tm = TypeBase.mk_case (scrut_var, pats) + in + (scrut_var, case_tm) + end + val tuple_cases = map mk_case_of_tuple dtys + + (* For every sum, create a match to extract one of the tuples *) + fun mk_sum_case ((tuple_var, tuple_case), (nvar, case_end)) = + let + val left_pat = sumSyntax.mk_inl (tuple_var, type_of nvar) + val right_pat = sumSyntax.mk_inr (nvar, type_of tuple_var) + val scrut = genvar (sumSyntax.mk_sum (type_of tuple_var, type_of nvar)) + val pats = [(left_pat, tuple_case), (right_pat, case_end)] + val case_tm = TypeBase.mk_case (scrut, pats) + in + (scrut, case_tm) + end + val tuple_cases = rev tuple_cases + val (nvar, case_end) = hd tuple_cases + val tuple_cases = tl tuple_cases + val (scrut, case_tm) = foldl mk_sum_case (nvar, case_end) tuple_cases + + (* Create the function *) + val abs_tm = mk_abs (scrut, case_tm) + + (* Add the “measure term” *) + val tm = inst [alpha_tyvar |-> type_of scrut] measure_tm + val tm = mk_comb (tm, abs_tm) + in + tm + end + +(* +val ty = “: (num # 'a) + (num # 'b) + (num # 'c)” + +val tys = hd dtys +val num_ty::rem_tys = tys + +val (tuple_var, tuple_case) = hd tuple_cases +*) + fun mk_fuel_defs (def_tms : term list) : thm = let (* Retrieve the identifiers. @@ -116,17 +249,16 @@ fun mk_fuel_defs (def_tms : term list) : thm = Ex.: def_tm = “even (n : int) : bool result = if i = 0 then Return T else odd (i - 1))” We want to retrive: id = “even” *) - val app = lhs def_tm val ids = map (fst o strip_comb o lhs) def_tms (* In the definitions, replace the identifiers by new identifiers which use fuel. Ex.: def_fuel_tm = “ - even_fuel (fuel : nat) (n : int) : result bool = + even___fuel (fuel : nat) (n : int) : result bool = case fuel of 0 => Diverge - | SUC fuel => - if i = 0 then Return T else odd_fuel (i - 1))” + | SUC fuel' => + if i = 0 then Return T else odd_fuel fuel' (i - 1))” *) fun mk_fuel_id (id : term) : term = let @@ -138,16 +270,25 @@ fun mk_fuel_defs (def_tms : term list) : thm = in fuel_id end val fuel_ids = map mk_fuel_id ids - val fuel_ids_with_fuel = map (fn id => mk_comb (id, fuel_var)) fuel_ids + val fuel_ids_with_fuel0 = map (fn id => mk_comb (id, fuel_var0)) fuel_ids + val fuel_ids_with_fuel1 = map (fn id => mk_comb (id, fuel_var1)) fuel_ids (* Recurse through the terms and replace the calls *) - val rwr_thms = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel) - val fuel_tms = map (rhs o concl o (PURE_REWRITE_CONV rwr_thms)) def_tms + val rwr_thms0 = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel0) + val rwr_thms1 = map (ASSUME o mk_eq) (zip ids fuel_ids_with_fuel1) + + fun mk_fuel_tm (def_tm : term) : term = + let + val (tm0, tm1) = dest_eq def_tm + val tm0 = (rhs o concl o (PURE_REWRITE_CONV rwr_thms0)) tm0 + val tm1 = (rhs o concl o (PURE_REWRITE_CONV rwr_thms1)) tm1 + in mk_eq (tm0, tm1) end + val fuel_tms = map mk_fuel_tm def_tms (* Add the case over the fuel *) fun add_fuel_case (tm : term) : term = let - val (app, body) = dest_eq tm + val (f, body) = dest_eq tm (* Create the “Diverge” term with the proper type *) val body_ty = type_of body val return_ty = @@ -155,20 +296,47 @@ fun mk_fuel_defs (def_tms : term list) : thm = | _ => failwith "unexpected" val diverge_tm = mk_diverge_tm return_ty (* Create the “SUC fuel” term *) - val suc_tm = mk_comb (num_suc_tm, fuel_var) - val fuel_tm = TypeBase.mk_case (fuel_var, [(num_zero_tm, diverge_tm), (suc_tm, body)]) - in mk_eq (app, fuel_tm) end + val suc_tm = mk_comb (num_suc_tm, fuel_var1) + val fuel_tm = + TypeBase.mk_case (fuel_var0, [(num_zero_tm, diverge_tm), (suc_tm, body)]) + in mk_eq (f, fuel_tm) end val fuel_tms = map add_fuel_case fuel_tms (* Define the auxiliary definitions which use fuel *) val fuel_defs_conj = list_mk_conj fuel_tms + (* The definition name *) + val def_name = (fst o dest_var o hd) fuel_ids + (* The tactic to prove the termination *) + val rty = ref “:bool” + fun prove_termination_tac (asms, g) = + let + val _ = print_term g + val r_tm = (fst o dest_exists) g + val _ = rty := type_of r_tm + val ty = (hd o snd o dest_type) (!rty) + val m_tm = mk_termination_measure_from_ty ty + in + WF_REL_TAC ‘^m_tm’ (asms, g) + end (* Define the fuel definitions *) - val fuel_defs_thm = Define ‘^fuel_defs_conj’ + (* + val temp_def = Hol_defn def_name ‘^fuel_defs_conj’ + Defn.tgoal temp_def + *) + val fuel_defs_thm = tDefine def_name ‘^fuel_defs_conj’ prove_termination_tac in fuel_defs_thm end (* +val (asms, g) = top_goal () + +val rty = ref “:num” +val ty = !rty +(hd o snd o dest_type) ty +*) + +(* val (fuel_tms, fuel_defs_thm) = mk_fuel_defs def_tms val fuel_def_tms = map (snd o strip_forall) ((strip_conj o concl) fuel_defs_thm) val (def_tm, fuel_def_tm) = hd (zip def_tms fuel_def_tms) @@ -215,7 +383,7 @@ fun case_progress (asms, g) = val scrut = (strip_all_cases_get_scrutinee o lhs) g in Cases_on ‘^scrut’ (asms, g) end -(* Tactic to prove the fuel monotonicity theorem *) +(* Tactic to prove the fuel monotonicity theorem - TODO: move below *) fun prove_fuel_mono_tac (pred_def_thms : thm list) (fuel_defs_thm : thm) = Induct_on ‘^fuel_var0’ >-( (* The ___P predicates are false: n is 0 *) @@ -269,8 +437,6 @@ fun prove_fuel_mono (pred_def_thms : thm list) (fuel_defs_thm : thm) : thm = val fuel_eq_tms = map mk_fuel_eq_tm pred_fuel_tms (* Create the conjunction *) val fuel_eq_tms = list_mk_conj fuel_eq_tms -(* (* Add the “n <= m ==> ...” implication *) - val fuel_eq_tms = mk_imp (fuel_vars_le, fuel_eq_tms) *) (* Qantify over the fuels *) val fuel_eq_tms = list_mk_forall ([fuel_var0, fuel_var1], fuel_eq_tms) in @@ -587,6 +753,7 @@ val expand_defs = gen_expand_defs def_tms *) fun mk_termination_diverge_tms + (def_tms : term list) (pred_def_thms : thm list) (raw_def_thms : thm list) (expand_defs : thm list) : @@ -694,7 +861,7 @@ fun prove_termination_thms (* Add the implication *) val tm = mk_imp (pred_tm, tm) (* Quantify *) - val (_, args) = strip_comb fun_tm + val (_, args) = strip_comb (lhs fun_eq_tm) val tm = list_mk_forall (args, tm) (* Prove *) @@ -838,7 +1005,7 @@ fun prove_divergence_thms val tm = list_mk_imp ([pred_tm, pred_suc_tm], tm) (* Quantify *) - val (_, args) = strip_comb fun_tm + val (_, args) = strip_comb (lhs fun_eq_tm) val tm = list_mk_forall (args, tm) (* Prove *) @@ -888,3 +1055,154 @@ set_goal ([], tm) val (asms, g) = top_goal () *) + +(* Prove the final lemmas: + + {[ + !i. even i = even___expand even odd i + ]} + + Note that the shape of the theorem is very precise: this helps for the proof. + Also, by correctly ordering the assumptions, we make sure that by rewriting + we don't convert one of the two to “T”. + *) +fun prove_final_eqs + (term_div_tms : (term * term) list) + (termination_thms : thm list) + (divergence_thms : thm list) + (raw_def_thms : thm list) + : thm list = + let + fun prove_one ((pred_tm, fun_eq_tm), (termination_thm, divergence_thm)) : thm = + let + val (_, args) = strip_comb (lhs fun_eq_tm) + val g = list_mk_forall (args, fun_eq_tm) + (* We make a case disjunction of the subgoal: “exists n. even___P i n” *) + val exists_g = (rhs o concl) (PURE_REWRITE_CONV raw_def_thms (lhs fun_eq_tm)) + val (_, exists_g, _) = TypeBase.dest_case exists_g + val prove_tac = + rpt gen_tac >> + Cases_on ‘^exists_g’ + >-( (* Termination *) + irule termination_thm >> pure_asm_rewrite_tac []) + (* Divergence *) + >> irule divergence_thm >> fs [] + + in + save_goal_and_prove (g, prove_tac) + end + in + map prove_one (zip term_div_tms (zip termination_thms divergence_thms)) + end + +(* +val termination_thm = hd termination_thms +val divergence_thm = hd divergence_thms +set_goal ([], g) +*) + +(* The final function: define potentially diverging functions in an error monad *) +fun DefineDiv (def_qt : term quotation) = + let + (* Parse the definitions. + + Example: + {[ + (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\ + (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1)) + ]} + *) + val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt) + + (* Generate definitions which use some fuel + + Example: + {[ + even___fuel n i = + case fuel of + 0 => Diverge + | SUC fuel => + if i = 0 then Return T else odd_fuel (i - 1)) + ]} + *) + (* TODO: list of theorems *) + val fuel_defs_thm = mk_fuel_defs def_tms + + (* Generate the predicate definitions. + + {[ even___P n i = = ~is_diverge (even___fuel n i) ]} + *) + val fuel_def_tms = map (snd o strip_forall) ((strip_conj o concl) fuel_defs_thm) + val pred_def_thms = map mk_fuel_predicate_defs (zip def_tms fuel_def_tms) + + (* Prove the monotonicity property for the fuel, all at once + + *) + val fuel_mono_thm = prove_fuel_mono pred_def_thms fuel_defs_thm + + (* Prove the individual fuel functions - TODO: update + + {[ + !n i. $LEAST (even___P i) <= n ==> even___fuel n i = even___fuel ($LEAST (even___P i)) i + ]} + *) + val least_fuel_mono_thms = prove_least_fuel_mono pred_def_thms fuel_mono_thm + + (* + {[ + !n i. even___P i n ==> $LEAST (even___P i) <= n + ]} + *) + val least_pred_thms = prove_least_pred_thms pred_def_thms + + (* + {[ + !n i. even___P i n ==> even___P i ($LEAST (even___P i)) + ]} + *) + val pred_n_imp_pred_least_thms = prove_pred_n_imp_pred_least_thms pred_def_thms + + (* + "Raw" definitions: + + {[ + even i = if (?n. even___P i n) then even___P ($LEAST (even___P i)) i else Diverge + ]} + *) + val raw_def_thms = define_raw_defs def_tms pred_def_thms fuel_defs_thm + + (* + !n i. even___P i n ==> even___fuel n i = even i + *) + val pred_imp_fuel_eq_raw_def_thms = + prove_pred_imp_fuel_eq_raw_def_thms + pred_def_thms fuel_def_tms least_fuel_mono_thms + least_pred_thms pred_n_imp_pred_least_thms raw_def_thms + + (* "Expand" definitions *) + val expand_defs = gen_expand_defs def_tms + + (* Small utility *) + val term_div_tms = mk_termination_diverge_tms def_tms pred_def_thms raw_def_thms expand_defs + + (* Termination theorems *) + val termination_thms = + prove_termination_thms term_div_tms fuel_defs_thm pred_def_thms + raw_def_thms expand_defs pred_n_imp_pred_least_thms pred_imp_fuel_eq_raw_def_thms + + (* Divergence theorems *) + val divergence_thms = + prove_divergence_thms term_div_tms fuel_defs_thm pred_def_thms raw_def_thms expand_defs + + (* Final theorems: + + {[ + ∀i. even i = even___E even odd i, + ⊢ ∀i. odd i = odd___E even odd i + ]} + *) + val final_eqs = prove_final_eqs term_div_tms termination_thms divergence_thms raw_def_thms + in + (* We return the final equations, which act as rewriting theorems *) + final_eqs + end |