diff options
author | Son Ho | 2023-02-01 00:10:19 +0100 |
---|---|---|
committer | Son HO | 2023-06-04 21:54:38 +0200 |
commit | 49903e84b1193565baa04a25864d6e54fed6f1de (patch) | |
tree | 28b2f49e10f5b97f071192264f6cafa5f2c5ec55 | |
parent | 577484c6c70f6e80e94516b944dd6d9dd06897d0 (diff) |
Fix minor bugs in divDefLib.sml
-rw-r--r-- | backends/hol4/divDefLib.sml | 162 |
1 files changed, 134 insertions, 28 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml index 351d4bf0..08c8b91a 100644 --- a/backends/hol4/divDefLib.sml +++ b/backends/hol4/divDefLib.sml @@ -6,28 +6,35 @@ open boolTheory arithmeticTheory integerTheory intLib listTheory stringTheory open primitivesArithTheory primitivesBaseTacLib ilistTheory primitivesTheory open primitivesLib -(* TODO: -raw_def_thms -> raw_defs -raw_def_thm -> raw_def +(* TODO: move *) +Theorem case_result_same_eq: + !(r : 'a result). + (case r of + Return x => Return x + | Fail e => Fail e + | Diverge => Diverge) = r +Proof + rw [] >> CASE_TAC +QED -fuel_defs -> fuel_defs (and split the theorem) +(* +val ty = id_ty +strip_arrows ty *) (* TODO: move *) fun list_mk_arrow (tys : hol_type list) (ret_ty : hol_type) : hol_type = foldr (fn (ty, aty) => ty --> aty) ret_ty tys -(* -val def_qt = ‘ - (even (i : int) : bool result = - if i = 0 then Return T else odd (i - 1)) /\ - (odd (i : int) : bool result = - if i = 0 then Return F else even (i - 1)) -’ - -val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt) -val def_tm = hd def_tms -*) +(* TODO: move *) +fun strip_arrows (ty : hol_type) : hol_type list * hol_type = + let + val (ty0, ty1) = dom_rng ty + val (tys, ret) = strip_arrows ty1 + in + (ty0::tys, ret) + end + handle HOL_ERR _ => ([], ty) (* Small utilities *) val current_goal : term option ref = ref NONE @@ -317,11 +324,8 @@ fun mk_fuel_predicate_defs (def_tm, fuel_def_tm) : thm = (* From [even i] create the term [even_P i n], where [n] is the fuel *) val (id, args) = (strip_comb o lhs) def_tm val (id_str, id_ty) = dest_var id - val {Args=tys, Thy=thy, Tyop=tyop} = dest_thy_type id_ty - val _ = assert (fn x => x = "fun") tyop; - val tys = rev tys; - val ret_ty = hd tys; - val tys = rev (num_ty :: tl tys); + val (tys, ret_ty) = strip_arrows id_ty + val tys = append tys [num_ty] val pred_ty = list_mk_arrow tys bool_ty val pred_id = mk_var (id_str ^ fuel_predicate_suffix, pred_ty) val pred_tm = list_mk_comb (pred_id, append args [fuel_var]) @@ -338,6 +342,8 @@ fun mk_fuel_predicate_defs (def_tm, fuel_def_tm) : thm = end (* +val (def_tm, fuel_def_tm) = hd (zip def_tms fuel_def_tms) + val pred_defs = map mk_fuel_predicate_defs (zip def_tms fuel_def_tms) *) @@ -362,6 +368,35 @@ fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm = val pred_tms = map (lhs o snd o strip_forall o concl) pred_defs val fuel_tms = map (lhs o snd o strip_forall o concl) fuel_defs val pred_fuel_tms = zip pred_tms fuel_tms + (* Create a set containing the names of all the functions in the recursive group *) + val rec_fun_set = + Redblackset.fromList const_name_compare (map get_fun_name_from_app fuel_tms) + (* Small tactic which rewrites the occurrences of recursive calls *) + fun rewrite_rec_call (asms, g) = + let + val scrut = (strip_all_cases_get_scrutinee o lhs) g + val fun_id = get_fun_name_from_app scrut (* This can fail *) + in + (* Check if the function is part of the group we are considering *) + if Redblackset.member (rec_fun_set, fun_id) then + let + (* Yes: use the induction hypothesis *) + fun apply_ind_hyp (ind_th : thm) : tactic = + let + val th = SPEC_ALL ind_th + val th_pat = (lhs o snd o strip_imp o concl) th + val (var_s, ty_s) = match_term th_pat scrut + (* Note that in practice the type instantiation should be empty *) + val th = INST var_s (INST_TYPE ty_s th) + in + assume_tac th + end + in + (last_assum apply_ind_hyp >> fs []) (asms, g) + end + else all_tac (asms, g) + end + handle HOL_ERR _ => all_tac (asms, g) (* Generate terms of the shape: !i. n <= m ==> even___P i n ==> even___fuel n i = even___fuel m i *) @@ -379,30 +414,44 @@ fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm = val fuel_eq_tm = mk_imp (fuel_vars_le, fuel_eq_tm) (* Quantify *) val fuel_eq_tm = list_mk_forall (vars, fuel_eq_tm) - in fuel_eq_tm end + in + fuel_eq_tm + end val fuel_eq_tms = map mk_fuel_eq_tm pred_fuel_tms (* Create the conjunction *) val fuel_eq_tms = list_mk_conj fuel_eq_tms (* Qantify over the fuels *) val fuel_eq_tms = list_mk_forall ([fuel_var0, fuel_var1], fuel_eq_tms) - (* The tactic for the proof *) + (* The tactics for the proof *) val prove_tac = Induct_on ‘^fuel_var0’ >-( (* The ___P predicates are false: n is 0 *) fs pred_defs >> fs [is_diverge_def] >> pure_once_rewrite_tac fuel_defs >> fs []) >> + (* Introduce n *) gen_tac >> + (* Introduce m *) Cases_on ‘^fuel_var1’ >-( (* Contradiction: SUC n < 0 *) rw [] >> exfalso >> int_tac) >> fs pred_defs >> fs [is_diverge_def] >> pure_once_rewrite_tac fuel_defs >> fs [bind_def] >> + (* Introduce in the context *) + rpt gen_tac >> + (* Split the goals - note that we prove one big goal for all the functions at once *) + rpt strip_tac >> + (* Instantiate the assumption: !m. n <= m ==> ~(...) + with the proper m. + *) + last_x_assum imp_res_tac >> + (* Make sure the induction hypothesis is always the last assumption *) + last_x_assum assume_tac >> (* Split the goals *) - rw [] >> + rpt strip_tac >> fs [case_result_same_eq] >> (* Explore all the paths *) - rpt (case_progress >> fs []) + rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq]) in (* Prove *) save_goal_and_prove (fuel_eq_tms, prove_tac) @@ -410,6 +459,8 @@ fun prove_fuel_mono (pred_defs : thm list) (fuel_defs : thm list) : thm = (* val fuel_mono_thm = prove_fuel_mono pred_defs fuel_defs + +set_goal ([], fuel_eq_tms) *) (* Prove the property about the least upper bound. @@ -452,6 +503,10 @@ fun prove_least_fuel_mono (pred_defs : thm list) (fuel_mono_thm : thm) : thm lis map mk_least_fuel_thm (zip pred_defs thl) end +(* +val (pred_def, mono_thm) = hd (zip pred_defs thl) +*) + (* Prove theorems of the shape: {[ @@ -843,10 +898,10 @@ fun prove_termination_thms rpt disch_tac >> (* Expand the binds *) - fs [bind_def] >> + fs [bind_def, case_result_same_eq] >> (* Explore all the paths by doing case disjunctions *) - rpt (rewrite_rec_call >> case_progress >> fs []) + rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq]) in save_goal_and_prove (tm, prove_tac) end @@ -860,6 +915,9 @@ val termination_thms = prove_termination_thms term_div_tms fuel_defs pred_defs raw_defs expand_defs pred_n_imp_pred_least_thms pred_imp_fuel_eq_raw_defs + +val ((pred_tm, fun_eq_tm), pred_n_imp_pred_least_thm) = hd (zip term_div_tms pred_n_imp_pred_least_thms) +set_goal ([], tm) *) (* Prove the divergence lemmas: @@ -969,10 +1027,10 @@ fun prove_divergence_thms *) pop_assum mp_tac >> pure_once_rewrite_tac fuel_defs >> - rpt disch_tac >> fs [] >> + rpt disch_tac >> fs [bind_def, case_result_same_eq] >> (* Evaluate all the paths *) - rpt (rewrite_rec_call >> case_progress >> fs []) + rpt (rewrite_rec_call >> case_progress >> fs [case_result_same_eq]) in save_goal_and_prove (tm, prove_tac) end @@ -981,6 +1039,9 @@ fun prove_divergence_thms end (* +val (pred_tm, fun_eq_tm) = hd term_div_tms +set_goal ([], tm) + val divergence_thms = prove_divergence_thms term_div_tms @@ -1135,12 +1196,14 @@ fun DefineDiv (def_qt : term quotation) = ]} *) val final_eqs = prove_final_eqs term_div_tms termination_thms divergence_thms raw_defs + val final_eqs = map (PURE_REWRITE_RULE expand_defs) final_eqs in (* We return the final equations, which act as rewriting theorems *) final_eqs end (* + val def_qt = ‘ (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\ @@ -1150,4 +1213,47 @@ val def_qt = ‘ val even_def = DefineDiv def_qt + +Datatype: + list_t = + ListCons 't list_t + | ListNil +End + +val def_qt = ‘ + nth_mut_fwd (ls : 't list_t) (i : u32) : 't result = + case ls of + | ListCons x tl => + if u32_to_int i = (0:int) + then Return x + else + do + i0 <- u32_sub i (int_to_u32 1); + nth_mut_fwd tl i0 + od + | ListNil => + Fail Failure +’ + +val nth_mut_fwd_def = DefineDiv def_qt + +(* Checking what happens with non terminal calls *) +val def_qt = ‘ + nth_mut_fwd (ls : 't list_t) (i : u32) : 't result = + case ls of + | ListCons x tl => + if u32_to_int i = (0:int) + then Return x + else + do + i0 <- u32_sub i (int_to_u32 1); + x <- nth_mut_fwd tl i0; + Return x + od + | ListNil => + Fail Failure +’ + +val nth_mut_fwd_def = DefineDiv def_qt + *) |