summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-02-01 00:10:19 +0100
committerSon HO2023-06-04 21:54:38 +0200
commit49903e84b1193565baa04a25864d6e54fed6f1de (patch)
tree28b2f49e10f5b97f071192264f6cafa5f2c5ec55
parent577484c6c70f6e80e94516b944dd6d9dd06897d0 (diff)
Fix minor bugs in divDefLib.sml
-rw-r--r--backends/hol4/divDefLib.sml162
1 files changed, 134 insertions, 28 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml
index 351d4bf0..08c8b91a 100644
--- a/backends/hol4/divDefLib.sml
+++ b/backends/hol4/divDefLib.sml
@@ -6,28 +6,35 @@ open boolTheory arithmeticTheory integerTheory intLib listTheory stringTheory
open primitivesArithTheory primitivesBaseTacLib ilistTheory primitivesTheory
open primitivesLib
-(* TODO:
-raw_def_thms -> raw_defs
-raw_def_thm -> raw_def
+(* TODO: move *)
+Theorem case_result_same_eq:
+ !(r : 'a result).
+ (case r of
+ Return x => Return x
+ | Fail e => Fail e
+ | Diverge => Diverge) = r
+Proof
+ rw [] >> CASE_TAC
+QED
-fuel_defs -> fuel_defs (and split the theorem)
+(*
+val ty = id_ty
+strip_arrows ty
*)
(* TODO: move *)
fun list_mk_arrow (tys : hol_type list) (ret_ty : hol_type) : hol_type =
foldr (fn (ty, aty) => ty --> aty) ret_ty tys
-(*
-val def_qt = ‘
- (even (i : int) : bool result =
- if i = 0 then Return T else odd (i - 1)) /\
- (odd (i : int) : bool result =
- if i = 0 then Return F else even (i - 1))
-’
-
-val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
-val def_tm = hd def_tms
-*)
+(* TODO: move *)
+fun strip_arrows (ty : hol_type) : hol_type list * hol_type =
+ let
+ val (ty0, ty1) = dom_rng ty
+ val (tys, ret) = strip_arrows ty1
+ in
+ (ty0::tys, ret)
+ end
+ handle HOL_ERR _ => ([], ty)
(* Small utilities *)
val current_goal : term option ref = ref NONE
@@ -317,11 +324,8 @@ fun mk_fuel_predicate_defs (def_tm, fuel_def_tm) : thm =
(* From [even i] create the term [even_P i n], where [n] is the fuel *)
val (id, args) = (strip_comb o lhs) def_tm
val (id_str, id_ty) = dest_var id
- val {Args=tys, Thy=thy, Tyop=tyop} = dest_thy_type id_ty
- val _ = assert (fn x => x = "fun") tyop;
- val tys = rev tys;
- val ret_ty = hd tys;
- val tys = rev (num_ty :: tl tys);
+ val (tys, ret_ty) = strip_arrows id_ty
+ val tys = append tys [num_ty]
val pred_ty = list_mk_arrow tys bool_ty
val pred_id = mk_var (id_str ^ fuel_predicate_suffix, pred_ty)
val pred_tm = list_mk_comb (pred_id, append args [fuel_var])
@@ -338,6 +342,8 @@ fun mk_fuel_predicate_defs (def_tm, fuel_def_tm) : thm =
end
(*
+val (def_tm, fuel_def_tm) = hd (zip def_tms fuel_def_tms)
+
val pred_defs = map mk_fuel_predicate_defs (zip def_tms fuel_def_tms)
*)
@@ -362,6 +368,35 @@ fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm =
val pred_tms = map (lhs o snd o strip_forall o concl) pred_defs
val fuel_tms = map (lhs o snd o strip_forall o concl) fuel_defs
val pred_fuel_tms = zip pred_tms fuel_tms
+ (* Create a set containing the names of all the functions in the recursive group *)
+ val rec_fun_set =
+ Redblackset.fromList const_name_compare (map get_fun_name_from_app fuel_tms)
+ (* Small tactic which rewrites the occurrences of recursive calls *)
+ fun rewrite_rec_call (asms, g) =
+ let
+ val scrut = (strip_all_cases_get_scrutinee o lhs) g
+ val fun_id = get_fun_name_from_app scrut (* This can fail *)
+ in
+ (* Check if the function is part of the group we are considering *)
+ if Redblackset.member (rec_fun_set, fun_id) then
+ let
+ (* Yes: use the induction hypothesis *)
+ fun apply_ind_hyp (ind_th : thm) : tactic =
+ let
+ val th = SPEC_ALL ind_th
+ val th_pat = (lhs o snd o strip_imp o concl) th
+ val (var_s, ty_s) = match_term th_pat scrut
+ (* Note that in practice the type instantiation should be empty *)
+ val th = INST var_s (INST_TYPE ty_s th)
+ in
+ assume_tac th
+ end
+ in
+ (last_assum apply_ind_hyp >> fs []) (asms, g)
+ end
+ else all_tac (asms, g)
+ end
+ handle HOL_ERR _ => all_tac (asms, g)
(* Generate terms of the shape:
!i. n <= m ==> even___P i n ==> even___fuel n i = even___fuel m i
*)
@@ -379,30 +414,44 @@ fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm =
val fuel_eq_tm = mk_imp (fuel_vars_le, fuel_eq_tm)
(* Quantify *)
val fuel_eq_tm = list_mk_forall (vars, fuel_eq_tm)
- in fuel_eq_tm end
+ in
+ fuel_eq_tm
+ end
val fuel_eq_tms = map mk_fuel_eq_tm pred_fuel_tms
(* Create the conjunction *)
val fuel_eq_tms = list_mk_conj fuel_eq_tms
(* Qantify over the fuels *)
val fuel_eq_tms = list_mk_forall ([fuel_var0, fuel_var1], fuel_eq_tms)
- (* The tactic for the proof *)
+ (* The tactics for the proof *)
val prove_tac =
Induct_on ‘^fuel_var0’ >-(
(* The ___P predicates are false: n is 0 *)
fs pred_defs >>
fs [is_diverge_def] >>
pure_once_rewrite_tac fuel_defs >> fs []) >>
+ (* Introduce n *)
gen_tac >>
+ (* Introduce m *)
Cases_on ‘^fuel_var1’ >-(
(* Contradiction: SUC n < 0 *)
rw [] >> exfalso >> int_tac) >>
fs pred_defs >>
fs [is_diverge_def] >>
pure_once_rewrite_tac fuel_defs >> fs [bind_def] >>
+ (* Introduce in the context *)
+ rpt gen_tac >>
+ (* Split the goals - note that we prove one big goal for all the functions at once *)
+ rpt strip_tac >>
+ (* Instantiate the assumption: !m. n <= m ==> ~(...)
+ with the proper m.
+ *)
+ last_x_assum imp_res_tac >>
+ (* Make sure the induction hypothesis is always the last assumption *)
+ last_x_assum assume_tac >>
(* Split the goals *)
- rw [] >>
+ rpt strip_tac >> fs [case_result_same_eq] >>
(* Explore all the paths *)
- rpt (case_progress >> fs [])
+ rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq])
in
(* Prove *)
save_goal_and_prove (fuel_eq_tms, prove_tac)
@@ -410,6 +459,8 @@ fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm =
(*
val fuel_mono_thm = prove_fuel_mono pred_defs fuel_defs
+
+set_goal ([], fuel_eq_tms)
*)
(* Prove the property about the least upper bound.
@@ -452,6 +503,10 @@ fun prove_least_fuel_mono (pred_defs : thm list) (fuel_mono_thm : thm) : thm lis
map mk_least_fuel_thm (zip pred_defs thl)
end
+(*
+val (pred_def, mono_thm) = hd (zip pred_defs thl)
+*)
+
(* Prove theorems of the shape:
{[
@@ -843,10 +898,10 @@ fun prove_termination_thms
rpt disch_tac >>
(* Expand the binds *)
- fs [bind_def] >>
+ fs [bind_def, case_result_same_eq] >>
(* Explore all the paths by doing case disjunctions *)
- rpt (rewrite_rec_call >> case_progress >> fs [])
+ rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq])
in
save_goal_and_prove (tm, prove_tac)
end
@@ -860,6 +915,9 @@ val termination_thms =
prove_termination_thms term_div_tms fuel_defs pred_defs
raw_defs expand_defs pred_n_imp_pred_least_thms
pred_imp_fuel_eq_raw_defs
+
+val ((pred_tm, fun_eq_tm), pred_n_imp_pred_least_thm) = hd (zip term_div_tms pred_n_imp_pred_least_thms)
+set_goal ([], tm)
*)
(* Prove the divergence lemmas:
@@ -969,10 +1027,10 @@ fun prove_divergence_thms
*)
pop_assum mp_tac >>
pure_once_rewrite_tac fuel_defs >>
- rpt disch_tac >> fs [] >>
+ rpt disch_tac >> fs [bind_def, case_result_same_eq] >>
(* Evaluate all the paths *)
- rpt (rewrite_rec_call >> case_progress >> fs [])
+ rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq])
in
save_goal_and_prove (tm, prove_tac)
end
@@ -981,6 +1039,9 @@ fun prove_divergence_thms
end
(*
+val (pred_tm, fun_eq_tm) = hd term_div_tms
+set_goal ([], tm)
+
val divergence_thms =
prove_divergence_thms
term_div_tms
@@ -1135,12 +1196,14 @@ fun DefineDiv (def_qt : term quotation) =
]}
*)
val final_eqs = prove_final_eqs term_div_tms termination_thms divergence_thms raw_defs
+ val final_eqs = map (PURE_REWRITE_RULE expand_defs) final_eqs
in
(* We return the final equations, which act as rewriting theorems *)
final_eqs
end
(*
+
val def_qt = ‘
(even (i : int) : bool result =
if i = 0 then Return T else odd (i - 1)) /\
@@ -1150,4 +1213,47 @@ val def_qt = ‘
val even_def = DefineDiv def_qt
+
+Datatype:
+ list_t =
+ ListCons 't list_t
+ | ListNil
+End
+
+val def_qt = ‘
+ nth_mut_fwd (ls : 't list_t) (i : u32) : 't result =
+ case ls of
+ | ListCons x tl =>
+ if u32_to_int i = (0:int)
+ then Return x
+ else
+ do
+ i0 <- u32_sub i (int_to_u32 1);
+ nth_mut_fwd tl i0
+ od
+ | ListNil =>
+ Fail Failure
+’
+
+val nth_mut_fwd_def = DefineDiv def_qt
+
+(* Checking what happens with non terminal calls *)
+val def_qt = ‘
+ nth_mut_fwd (ls : 't list_t) (i : u32) : 't result =
+ case ls of
+ | ListCons x tl =>
+ if u32_to_int i = (0:int)
+ then Return x
+ else
+ do
+ i0 <- u32_sub i (int_to_u32 1);
+ x <- nth_mut_fwd tl i0;
+ Return x
+ od
+ | ListNil =>
+ Fail Failure
+’
+
+val nth_mut_fwd_def = DefineDiv def_qt
+
*)