diff options
author | Son Ho | 2023-05-12 18:56:39 +0200 |
---|---|---|
committer | Son HO | 2023-06-04 21:54:38 +0200 |
commit | c49fd4b6230a1f926e929f133794b6f73d338077 (patch) | |
tree | 85de1862bbfb7f79bbe0b83722ae76ef1a52dead | |
parent | 2e5415364d3ad61a524980621d25713aa31f6e79 (diff) |
Reimplement DefineDiv with the new fixed point operator
Diffstat (limited to '')
-rw-r--r-- | backends/hol4/divDefProto2Script.sml | 1 | ||||
-rw-r--r-- | backends/hol4/divDefProto2TestScript.sml | 1245 | ||||
-rw-r--r-- | backends/hol4/primitivesBaseTacLib.sml | 5 |
3 files changed, 1249 insertions, 2 deletions
diff --git a/backends/hol4/divDefProto2Script.sml b/backends/hol4/divDefProto2Script.sml index 9dc43ea7..9efe835b 100644 --- a/backends/hol4/divDefProto2Script.sml +++ b/backends/hol4/divDefProto2Script.sml @@ -8,7 +8,6 @@ open primitivesLib val _ = new_theory "divDefProto2" - (* * Test with a general validity predicate. * diff --git a/backends/hol4/divDefProto2TestScript.sml b/backends/hol4/divDefProto2TestScript.sml new file mode 100644 index 00000000..39719b65 --- /dev/null +++ b/backends/hol4/divDefProto2TestScript.sml @@ -0,0 +1,1245 @@ +open HolKernel boolLib bossLib Parse +open boolTheory arithmeticTheory integerTheory intLib listTheory stringTheory + +open primitivesArithTheory primitivesBaseTacLib ilistTheory primitivesTheory +open primitivesLib +open divDefProto2Theory + +val _ = new_theory "divDefProto2TestScript" + +(*====================== + * Example 1: nth + *======================*) +Datatype: + list_t = + ListCons 't list_t + | ListNil +End + +(* We use this version of the body to prove that the body is valid *) +val nth_body_def = Define ‘ + nth_body (f : (('t list_t # u32) + 't) -> (('t list_t # u32) + 't) result) + (x : (('t list_t # u32) + 't)) : + (('t list_t # u32) + 't) result = + (* Destruct the input. We need this to call the proper function in case + of mutually recursive definitions, but also to eliminate arguments + which correspond to the output value (the input type is the same + as the output type). *) + case x of + | INL x => ( + let (ls, i) = x in + case ls of + | ListCons x tl => + if u32_to_int i = (0:int) + then Return (INR x) + else + do + i0 <- u32_sub i (int_to_u32 1); + r <- f (INL (tl, i0)); + (* Eliminate the invalid outputs. This is not necessary here, + but it is in the case of non tail call recursive calls. *) + case r of + | INL _ => Fail Failure + | INR i1 => Return (INR i1) + od + | ListNil => Fail Failure) + | INR _ => Fail Failure +’ + +val dbg = ref false +fun print_dbg s = if (!dbg) then print s else () + +(* Tactic which makes progress in a proof of validity by making a case + disjunction (we use this to explore all the paths in a function body). *) +fun prove_valid_case_progress + +(* +val (asms, g) = top_goal () +*) + +(* TODO: move *) +(* This theorem is important to shape the goal when proving that a body + satifies the fixed point validity property. + + Importantly: this theorem relies on the fact that errors are just transmitted + to the caller (in particular, without modification). + *) +Theorem case_result_switch_eq: + (case (case x of Return y => f y | Fail e => Fail e | Diverge => Diverge) of + | Return y => g y + | Fail e => Fail e + | Diverge => Diverge) = + (case x of + | Return y => + (case f y of + | Return y => g y + | Fail e => Fail e + | Diverge => Diverge) + | Fail e => Fail e + | Diverge => Diverge) +Proof + Cases_on ‘x’ >> fs [] +QED + +(* Tactic to prove that a body is valid: perform one step. *) +fun prove_body_is_valid_tac_step (asms, g) = + let + (* The goal has the shape: + {[ + (∀g h. ... g x = ... h x) ∨ + ∃h y. is_valid_fp_body n h ∧ ∀g. ... g x = ... od + ]} + *) + (* Retrieve the scrutinee in the goal (‘x’). + There are two cases: + - either the function has the shape: + {[ + (λ(y,z). ...) x + ]} + in which case we need to destruct ‘x’ + - or we have a normal ‘case ... of’ + *) + val body = (lhs o snd o strip_forall o fst o dest_disj) g + val scrut = + let + val (app, x) = dest_comb body + val (app, _) = dest_comb app + val {Name=name, Thy=thy, Ty = _ } = dest_thy_const app + in + if thy = "pair" andalso name = "UNCURRY" then x else failwith "not a curried argument" + end + handle HOL_ERR _ => strip_all_cases_get_scrutinee body + (* Retrieve the first quantified continuations from the goal (‘g’) *) + val qc = (hd o fst o strip_forall o fst o dest_disj) g + (* Check if the scrutinee is a recursive call *) + val (scrut_app, _) = strip_comb scrut + val _ = print_dbg ("prove_body_is_valid_step: Scrutinee: " ^ term_to_string scrut ^ "\n") + (* For the recursive calls: *) + fun step_rec () = + let + val _ = print_dbg ("prove_body_is_valid_step: rec call\n") + (* We need to instantiate the ‘h’ existantially quantified function *) + (* First, retrieve the body of the function: it is given by the ‘Return’ branch *) + val (_, _, branches) = TypeBase.dest_case body + (* Find the branch corresponding to the return *) + val ret_branch = List.find (fn (pat, _) => + let + val {Name=name, Thy=thy, Ty = _ } = (dest_thy_const o fst o strip_comb) pat + in + thy = "primitives" andalso name = "Return" + end) branches + val var = (hd o snd o strip_comb o fst o valOf) ret_branch + val br = (snd o valOf) ret_branch + (* Abstract away the input variable introduced by the pattern and the continuation ‘g’ *) + val h = list_mk_abs ([qc, var], br) + val _ = print_dbg ("prove_body_is_valid_step: h: " ^ term_to_string h ^ "\n") + (* Retrieve the input parameter ‘x’ *) + val input = (snd o dest_comb) scrut + val _ = print_dbg ("prove_body_is_valid_step: y: " ^ term_to_string input ^ "\n") + in + ((* Choose the right possibility (this is a recursive call) *) + disj2_tac >> + (* Instantiate the quantifiers *) + qexists ‘^h’ >> + qexists ‘^input’ >> + (* Unfold the predicate once *) + pure_once_rewrite_tac [is_valid_fp_body_def] >> + (* We have two subgoals: + - we have to prove that ‘h’ is valid + - we have to finish the proof of validity for the current body + *) + conj_tac >> fs [case_result_switch_eq]) + end + in + (* If recursive call: special treatment. Otherwise, we do a simple disjunction *) + (if term_eq scrut_app qc then step_rec () + else (Cases_on ‘^scrut’ >> fs [case_result_switch_eq])) (asms, g) + end + +(* Tactic to prove that a body is valid *) +fun prove_body_is_valid_tac (body_def : thm option) : tactic = + let val body_def_thm = case body_def of SOME th => [th] | NONE => [] + in + pure_once_rewrite_tac [is_valid_fp_body_def] >> gen_tac >> + (* Expand *) + fs body_def_thm >> + fs [bind_def, case_result_switch_eq] >> + (* Explore the body *) + rpt prove_body_is_valid_tac_step + end + +(* TODO: move *) +val is_valid_fp_body_tm = “is_valid_fp_body” + +(* Prove that a body satisfies the validity condition of the fixed point *) +fun prove_body_is_valid (body : term) : thm = + let + (* Explore the body and count the number of occurrences of nested recursive + calls so that we can properly instantiate the ‘N’ argument of ‘is_valid_fp_body’. + + We first retrieve the name of the continuation parameter. + Rem.: we generated fresh names so that, for instance, the continuation name + doesn't collide with other names. Because of this, we don't need to look for + collisions when exploring the body (and in the worst case, we would cound + an overapproximation of the number of recursive calls, which is perfectly + valid). + *) + val fcont = (hd o fst o strip_abs) body + val fcont_name = (fst o dest_var) fcont + fun max x y = if x > y then x else y + fun count_body_rec_calls (body : term) : int = + case dest_term body of + VAR (name, _) => if name = fcont_name then 1 else 0 + | CONST _ => 0 + | COMB (x, y) => max (count_body_rec_calls x) (count_body_rec_calls y) + | LAMB (_, x) => count_body_rec_calls x + val num_rec_calls = count_body_rec_calls body + + (* Generate the term ‘SUC (SUC ... (SUC n))’ where ‘n’ is a fresh variable. + + Remark: we first prove ‘is_valid_fp_body (SUC ... n) body’ then substitue + ‘n’ with ‘0’ to prevent the quantity from being rewritten to a bit + representation, which would prevent unfolding of the ‘is_valid_fp_body’. + *) + val nvar = genvar num_ty + (* Rem.: we stack num_rec_calls + 1 occurrences of ‘SUC’ (and the + 1 is important) *) + fun mk_n i = if i = 0 then mk_suc nvar else mk_suc (mk_n (i-1)) + val n_tm = mk_n num_rec_calls + + (* Generate the lemma statement *) + val is_valid_tm = list_mk_icomb (is_valid_fp_body_tm, [n_tm, body]) + val is_valid_thm = prove (is_valid_tm, prove_body_is_valid_tac NONE) + + (* Replace ‘nvar’ with ‘0’ *) + val is_valid_thm = INST [nvar |-> zero_num_tm] is_valid_thm + in + is_valid_thm + end + +(* +val (asms, g) = top_goal () +*) + +(* We first prove the theorem with ‘SUC (SUC n)’ where ‘n’ is a variable + to prevent this quantity from being rewritten to 2 *) +Theorem nth_body_is_valid_aux: + is_valid_fp_body (SUC (SUC n)) nth_body +Proof + prove_body_is_valid_tac (SOME nth_body_def) +QED + +Theorem nth_body_is_valid: + is_valid_fp_body (SUC (SUC 0)) nth_body +Proof + irule nth_body_is_valid_aux +QED + +val nth_raw_def = Define ‘ + nth (ls : 't list_t) (i : u32) = + case fix nth_body (INL (ls, i)) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + case r of + | INL _ => Fail Failure + | INR x => Return x +’ + +val fix_tm = “fix” + +(* Generate the raw definitions by using the grouped definition body and the + fixed point operator *) +fun mk_raw_defs (in_out_tys : (hol_type * hol_type) list) + (def_tms : term list) (body_is_valid : thm) : thm list = + let + (* Retrieve the body *) + val body = (List.last o snd o strip_comb o concl) body_is_valid + + (* Create the term ‘fix body’ *) + val fixed_body = mk_icomb (fix_tm, body) + + (* For every function in the group, generate the equation that we will + use as definition. In particular: + - add the properly injected input ‘x’ to ‘fix body’ (ex.: for ‘nth ls i’ + we add the input ‘INL (ls, i)’) + - wrap ‘fix body x’ into a case disjunction to extract the relevant output + + For instance, in the case of ‘nth ls i’: + {[ + nth (ls : 't list_t) (i : u32) = + case fix nth_body (INL (ls, i)) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + case r of + | INL _ => Fail Failure + | INR x => Return x + ]} + *) + fun mk_def_eq (i : int, def_tm : term) : term = + let + (* Retrieve the lhs of the original definition equation, and in + particular the inputs *) + val def_lhs = lhs def_tm + val args = (snd o strip_comb) def_lhs + + (* Inject the inputs into the param type *) + val input = pairSyntax.list_mk_pair args + val input = inject_in_param_sum in_out_tys i true input + + (* Compose*) + val def_rhs = mk_comb (fixed_body, input) + + (* Wrap in the case disjunction *) + val def_rhs = mk_case_select_result_sum def_rhs in_out_tys i false + + (* Create the equation *) + val def_eq_tm = mk_eq (def_lhs, def_rhs) + in + def_eq_tm + end + val raw_def_tms = map mk_def_eq (enumerate def_tms) + + (* Generate the definitions *) + val raw_defs = map (fn tm => Define ‘^tm’) raw_def_tms + in + raw_defs + end + +(* Tactic which makes progress in a proof by making a case disjunction (we use + this to explore all the paths in a function body). *) +fun case_progress (asms, g) = + let + val scrut = (strip_all_cases_get_scrutinee o lhs) g + in Cases_on ‘^scrut’ (asms, g) end + +(* Prove the final equation, that we will use as definition. *) +fun prove_def_eq_tac + (current_raw_def : thm) (all_raw_defs : thm list) (is_valid : thm) + (body_def : thm option) : tactic = + let + val body_def_thm = case body_def of SOME th => [th] | NONE => [] + in + rpt gen_tac >> + (* Expand the definition *) + pure_once_rewrite_tac [current_raw_def] >> + (* Use the fixed-point equality *) + pure_once_rewrite_left_tac [HO_MATCH_MP fix_fixed_eq is_valid] >> + (* Expand the body definition *) + pure_rewrite_tac body_def_thm >> + (* Expand all the definitions from the group *) + pure_rewrite_tac all_raw_defs >> + (* Explore all the paths - maybe we can be smarter, but this is fast and really easy *) + fs [bind_def] >> + rpt (case_progress >> fs []) + end + +(* Prove the final equations that we will give to the user as definitions *) +fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm list) : thm list= + let + val defs_tgt_raw = zip def_tms raw_defs + (* Substitute the function variables with the constants introduced in the raw + definitions *) + fun compute_fsubst (def_tm, raw_def) : {redex: term, residue: term} = + let + val (fvar, _) = (strip_comb o lhs) def_tm + val fconst = (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def + in + (fvar |-> fconst) + end + val fsubst = map compute_fsubst defs_tgt_raw + val defs_tgt_raw = map (fn (x, y) => (subst fsubst x, y)) defs_tgt_raw + + fun prove_def_eq (def_tm, raw_def) : thm = + let + (* Quantify the parameters *) + val (_, params) = (strip_comb o lhs) def_tm + val def_eq_tm = list_mk_forall (params, def_tm) + (* Prove *) + val def_eq = prove (def_eq_tm, prove_def_eq_tac raw_def raw_defs body_is_valid NONE) + in + def_eq + end + val def_eqs = map prove_def_eq defs_tgt_raw + in + def_eqs + end + +Theorem nth_def: + ∀ls i. nth (ls : 't list_t) (i : u32) : 't result = + case ls of + | ListCons x tl => + if u32_to_int i = (0:int) + then (Return x) + else + do + i0 <- u32_sub i (int_to_u32 1); + nth tl i0 + od + | ListNil => Fail Failure +Proof + prove_def_eq_tac nth_raw_def [nth_raw_def] nth_body_is_valid nth_body_def +QED + +(*====================== + * Example 2: even, odd + *======================*) + +val def_qt = ‘ + (even (i : int) : bool result = + if i = 0 then Return T else odd (i - 1)) /\ + (odd (i : int) : bool result = + if i = 0 then Return F else even (i - 1)) +’ + +val result_ty = “:'a result” +val error_ty = “:error” +val alpha_ty = “:'a” +val num_ty = “:num” + +val return_tm = “Return : 'a -> 'a result” +val fail_tm = “Fail : error -> 'a result” +val fail_failure_tm = “Fail Failure : 'a result” +val diverge_tm = “Diverge : 'a result” + +val zero_num_tm = “0:num” +val suc_tm = “SUC” + +fun mk_result (ty : hol_type) : hol_type = Type.type_subst [ alpha_ty |-> ty ] result_ty +fun dest_result (ty : hol_type) : hol_type = + let + val {Args=out_ty, Thy=thy, Tyop=tyop} = dest_thy_type ty + in + if thy = "primitives" andalso tyop = "result" then hd out_ty + else failwith "dest_result: not a result" + end + +fun mk_return (x : term) : term = mk_icomb (return_tm, x) +fun mk_fail (ty : hol_type) (e : term) : term = mk_comb (inst [ alpha_ty |-> ty ] fail_tm, e) +fun mk_fail_failure (ty : hol_type) : term = inst [ alpha_ty |-> ty ] fail_failure_tm +fun mk_diverge (ty : hol_type) : term = inst [ alpha_ty |-> ty ] diverge_tm + +fun mk_suc (n : term) = mk_comb (suc_tm, n) + +(* + *) + +(* **BODY GENERATION**: + ==================== + + When generating a recursive definition, we apply a fixed-point operator to + a function body. In case we define a group of mutually recursive definitions, + we generate *one* single body for the whole group of definitions. It works + as follows. + + The input of the body is an enumeration: we start by branching over this input, and + every branch corresponds to one function in the mutually recursive group. Also, the + inputs must be grouped into tuples. Whenever we make a recursive call, we wrap the + input parameters into the proper variant, so as to call the proper function. + + Moreover, the input of the body must have the same type as its output: we also + store the outputs of the functions in some variants of the enumeration. + + In order to make this work, we need to shape the body so that: + - input values/output values are properly injected into the enumeration + - whenever we get an output value (which is an enumeration), we extract + the value from the proper variant of the enumeration + + We encode the enumeration with a nested sum type, whose constructors + are ‘INL’ and ‘INR’. + + Example: + ======== + We consider the following group of mutually recursive definitions: *) + +val even_odd_qt = Defn.parse_quote ‘ + (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\ + (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1)) +’ + +(* From those equations, we generate the following body: *) + +val even_odd_body_def = Define ‘ + even_odd_body + (* The body takes a continuation - required by the fixed-point operator *) + (f : (int + bool + int + bool) -> (int + bool + int + bool) result) + + (* The type of the input is: + input of even + output of even + input of odd + output of odd *) + (x : int + bool + int + bool) : + + (* The output type is the same as the input type - this constraint + comes from limitations in the way we can define the fixed-point + operator inside the HOL logic *) + (int + bool + int + bool) result = + + (* Case disjunction over the input, in order to figure out which + function from the group is actually called (even, or odd). *) + case x of + | INL i => (* Input of even *) + (* Body of even *) + if i = 0 then Return (INR (INL T)) + else + (* Recursive calls are calls to the continuation f, wrapped + in the proper variant: here we call odd *) + (case f (INR (INR (INL (i - 1)))) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + (* Extract the proper value from the enumeration: here, the + call is tail-call so this is not really necessary, but we + might need to retrieve the output of the call to odd, which + is a boolean, and do something more complex with it. *) + case r of + | INL _ => Fail Failure + | INR (INL _) => Fail Failure + | INR (INR (INL _)) => Fail Failure + | INR (INR (INR b)) => (* Extract the output of odd *) + (* Return: inject into the variant for the output of even *) + Return (INR (INL b)) + ) + | INR (INL _) => (* Output of even *) + (* We must ignore this one *) + Fail Failure + | INR (INR (INL i)) => + (* Body of odd *) + if i = 0 then Return (INR (INR (INR F))) + else + (* Call to even *) + (case f (INL (i - 1)) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + (* Extract the proper value from the enumeration *) + case r of + | INL _ => Fail Failure + | INR (INL b) => (* Extract the output of even *) + (* Return: inject into the variant for the output of odd *) + Return (INR (INR (INR b))) + | INR (INR (INL _)) => Fail Failure + | INR (INR (INR _)) => Fail Failure + ) + | INR (INR (INR _)) => (* Output of odd *) + (* We must ignore this one *) + Fail Failure +’ + +(* Small helper to generate wrappers of the shape: ‘INL x’, ‘INR (INL x)’, etc. + Note that we should have: ‘length before_tys + 1 + length after tys >= 2’ + + Ex.: + ==== + The enumeration has type: “: 'a + 'b + 'c + 'd”. + We want to generate the variant which injects “x:'c” into this enumeration. + + We need to split the list of types into: + {[ + before_tys = [“:'a”, “'b”] + tm = “x: 'c” + after_tys = [“:'d”] + ]} + + The function will generate: + {[ + INR (INR (INL x) : 'a + 'b + 'c + 'd + ]} + + (* Debug *) + val before_tys = [“:'a”, “:'b”, “:'c”] + val tm = “x:'d” + val after_tys = [“:'e”, “:'f”] + + val before_tys = [“:'a”, “:'b”, “:'c”] + val tm = “x:'d” + val after_tys = [] + + mk_inl_inr_wrapper before_tys tm after_tys + *) +fun list_mk_inl_inr (before_tys : hol_type list) (tm : term) (after_tys : hol_type list) : + term = + let + val (before_tys, pat) = + if after_tys = [] + then + let + val just_before_ty = List.last before_tys + val before_tys = List.take (before_tys, List.length before_tys - 1) + val pat = sumSyntax.mk_inr (tm, just_before_ty) + in + (before_tys, pat) + end + else (before_tys, sumSyntax.mk_inl (tm, sumSyntax.list_mk_sum after_tys)) + val pat = foldr (fn (ty, pat) => sumSyntax.mk_inr (pat, ty)) pat before_tys + in + pat + end + + +(* This function wraps a term into the proper variant of the input/output + sum. + + Ex.: + ==== + For the input of the first function, we generate: ‘INL x’ + For the output of the first function, we generate: ‘INR (INL x)’ + For the input of the 2nd function, we generate: ‘INR (INR (INL x))’ + etc. + + If ‘is_input’ is true: we are wrapping an input. Otherwise we are wrapping + an output. + + (* Debug *) + val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”), (“:'e”, “:'f”)] + val j = 1 + val tm = “x:'c” + val tm = “y:'d” + val is_input = true + *) +fun inject_in_param_sum (tys : (hol_type * hol_type) list) (j : int) (is_input : bool) + (tm : term) : term = + let + fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) + val before_tys = flatten (List.take (tys, j)) + val (input_ty, output_ty) = List.nth (tys, j) + val after_tys = flatten (List.drop (tys, j + 1)) + val (before_tys, after_tys) = + if is_input then (before_tys, output_ty :: after_tys) + else (before_tys @ [input_ty], after_tys) + in + list_mk_inl_inr before_tys tm after_tys + end + +(* Remark: the order of the branches when creating matches is important. + For instance, in the case of ‘result’ it must be: ‘Return’, ‘Fail’, ‘Diverge’. + + For the purpose of stability and maintainability, we introduce this small helper + which reorders the cases in a pattern before actually creating the case + expression. + *) +fun unordered_mk_case (scrut: term, pats: (term * term) list) : term = + let + (* Retrieve the constructors *) + val cl = TypeBase.constructors_of (type_of scrut) + (* Retrieve the names of the constructors *) + val names = map (fst o dest_const) cl + (* Use those to reorder the patterns *) + fun is_pat (name : string) (pat, _) = + let + val app = (fst o strip_comb) pat + val app_name = (fst o dest_const) app + in + app_name = name + end + val pats = map (fn name => valOf (List.find (is_pat name) pats)) names + in + (* Create the case *) + TypeBase.mk_case (scrut, pats) + end + +(* Wrap a term of type “:'a result” into a ‘case of’ which matches over + the result. + + Ex.: + ==== + {[ + f x + + ~~> + + case f x of + | Fail e => Fail e + | Diverge => Diverge + | Return y => ... (* The branch content is generated by the continuation *) + ]} + + ‘gen_ret_branch’ is a *continuation* which generates the content of the + ‘Return’ branch (i.e., the content of the ‘...’ in the example above). + It receives as input the value contained by the ‘Return’ (i.e., the variable + ‘y’ in the example above). + + Remark.: the type of the term generated by ‘gen_ret_branch’ must have + the type ‘result’, but it can change the content of the result (i.e., + if ‘scrut’ has type ‘:'a result’, we can change the type of the wrapped + expression to ‘:'b result’). + + (* Debug *) + val scrut = “x: int result” + fun gen_ret_branch x = mk_return x + + val scrut = “x: int result” + fun gen_ret_branch _ = “Return T” + + mk_result_case scrut gen_ret_branch + *) +fun mk_result_case (scrut : term) (gen_ret_branch : term -> term) : term = + let + val scrut_ty = dest_result (type_of scrut) + (* Return branch *) + val ret_var = genvar scrut_ty + val ret_pat = mk_return ret_var + val ret_br = gen_ret_branch ret_var + val ret_ty = dest_result (type_of ret_br) + (* Failure branch *) + val fail_var = genvar error_ty + val fail_pat = mk_fail scrut_ty fail_var + val fail_br = mk_fail ret_ty fail_var + (* Diverge branch *) + val div_pat = mk_diverge scrut_ty + val div_br = mk_diverge ret_ty + in + unordered_mk_case (scrut, [(ret_pat, ret_br), (fail_pat, fail_br), (div_pat, div_br)]) + end + +(* Generate a ‘case ... of’ over a sum type. + + Ex.: + ==== + If the scrutinee is: “x : 'a + 'b + 'c” (i.e., the tys list is: [“:'a”, “:b”, “:c”]), + we generate: + + {[ + case x of + | INL y0 => ... (* Branch of index 0 *) + | INR (INL y1) => ... (* Branch of index 1 *) + | INR (INR (INL y2)) => ... (* Branch of index 2 *) + | INR (INR (INR y3)) => ... (* Branch of index 3 *) + ]} + + The content of the branches is generated by the ‘gen_branch’ continuation, + which receives as input the index of the branch as well as the variable + introduced by the pattern (in the example above: ‘y0’ for the branch 0, + ‘y1’ for the branch 1, etc.) + + (* Debug *) + val tys = [“:'a”, “:'b”] + val scrut = mk_var ("x", sumSyntax.list_mk_sum tys) + fun gen_branch i (x : term) = “F” + + val tys = [“:'a”, “:'b”, “:'c”, “:'d”] + val scrut = mk_var ("x", sumSyntax.list_mk_sum tys) + fun gen_branch i (x : term) = if type_of x = “:'c” then mk_return x else mk_fail_failure “:'c” + + list_mk_sum_case scrut tys gen_branch + *) +(* For debugging *) +val list_mk_sum_case_case = ref (“T”, [] : (term * term) list) +(* +val (scrut, [(pat1, br1), (pat2, br2)]) = !list_mk_sum_case_case +*) +fun list_mk_sum_case (scrut : term) (tys : hol_type list) + (gen_branch : int -> term -> term) : term = + let + (* Create the cases. Note that without sugar, the match actually looks like this: + {[ + case x of + | INL y0 => ... (* Branch of index 0 *) + | INR x1 + case x1 of + | INL y1 => ... (* Branch of index 1 *) + | INR x2 => + case x2 of + | INL y2 => ... (* Branch of index 2 *) + | INR y3 => ... (* Branch of index 3 *) + ]} + *) + fun create_case (j : int) (scrut : term) (tys : hol_type list) : term = + let + val _ = print_dbg ("list_mk_sum_case: " ^ + String.concatWith ", " (map type_to_string tys) ^ "\n") + in + case tys of + [] => failwith "tys is too short" + | [ ty ] => + (* Last element: no match to perform *) + gen_branch j scrut + | ty1 :: tys => + (* Not last: we create a pattern: + {[ + case scrut of + | INL pat_var1 => ... (* Branch of index i *) + | INR pat_var2 => + ... (* Generate this term recursively *) + ]} + *) + let + (* INL branch *) + val after_ty = sumSyntax.list_mk_sum tys + val pat_var1 = genvar ty1 + val pat1 = sumSyntax.mk_inl (pat_var1, after_ty) + val br1 = gen_branch j pat_var1 + (* INR branch *) + val pat_var2 = genvar after_ty + val pat2 = sumSyntax.mk_inr (pat_var2, ty1) + val br2 = create_case (j+1) pat_var2 tys + val _ = print_dbg ("list_mk_sum_case: assembling:\n" ^ + term_to_string scrut ^ ",\n" ^ + "[(" ^ term_to_string pat1 ^ ",\n " ^ term_to_string br1 ^ "),\n\n" ^ + " (" ^ term_to_string pat2 ^ ",\n " ^ term_to_string br2 ^ ")]\n\n") + val case_elems = (scrut, [(pat1, br1), (pat2, br2)]) + val _ = list_mk_sum_case_case := case_elems + in + (* Put everything together *) + TypeBase.mk_case case_elems + end + end + in + create_case 0 scrut tys + end + +(* Generate a ‘case ... of’ to select the input/output of the ith variant of + the param enumeration. + + Ex.: + ==== + There are two functions in the group, and we select the input of the function of index 1: + {[ + case x of + | INL _ => Fail Failure (* Input of function of index 0 *) + | INR (INL _) => Fail Failure (* Output of function of index 0 *) + | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *) + | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *) + ]} + + (* Debug *) + val tys = [(“:'a”, “:'b”)] + val scrut = “x : 'a + 'b” + val fi = 0 + val is_input = true + + val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”)] + val scrut = “x : 'a + 'b + 'c + 'd” + val fi = 1 + val is_input = false + + val scrut = mk_var ("x", sumSyntax.list_mk_sum (flatten tys)) + + list_mk_case_select scrut tys fi is_input + *) +fun list_mk_case_sum_select (scrut : term) (tys : (hol_type * hol_type) list) + (fi : int) (is_input : bool) : term = + let + (* The index of the element in the enumeration that we will select *) + val i = 2 * fi + (if is_input then 0 else 1) + (* Flatten the types and numerotate them *) + fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) + val tys = flatten tys + (* Get the return type *) + val ret_ty = List.nth (tys, i) + (* The continuation which will generate the content of the branches *) + fun gen_branch j var = if j = i then mk_return var else mk_fail_failure ret_ty + in + (* Generate the ‘case ... of’ *) + list_mk_sum_case scrut tys gen_branch + end + +(* Generate a ‘case ... of’ to select the input/output of the ith variant of + the param enumeration. + + Ex.: + ==== + There are two functions in the group, and we select the input of the function of index 1: + {[ + case x of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + case r of + | INL _ => Fail Failure (* Input of function of index 0 *) + | INR (INL _) => Fail Failure (* Output of function of index 0 *) + | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *) + | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *) + ]} + *) +fun mk_case_select_result_sum (scrut : term) (tys : (hol_type * hol_type) list) + (fi : int) (is_input : bool) : term = + (* We match over the result, then over the enumeration *) + mk_result_case scrut (fn x => list_mk_case_sum_select x tys fi is_input) + +(* +val scrut = call +val tys = in_out_tys +val is_input = false +val call = mk_case_select_result_sum call in_out_tys fi false +*) + +(* TODO: move *) +fun enumerate (ls : 'a list) : (int * 'a) list = + zip (List.tabulate (List.length ls, fn i => i)) ls + +(* Generate a body for the fixed-point operator from a quoted group of mutually + recursive definitions. + + See TODO for detailed explanations: from the quoted equations for ‘nth’ + (or for [‘even’, ‘odd’]) we generate the body ‘nth_body’ (or ‘even_odd_body’, + respectively). + *) +fun mk_body (fnames : string list) (in_out_tys : (hol_type * hol_type) list) + (def_tms : term list) : term = + let + val fnames_set = Redblackset.fromList String.compare fnames + + (* Compute a map from function name to function index *) + val fnames_map = Redblackmap.fromList String.compare + (map (fn (x, y) => (y, x)) (enumerate fnames)) + + (* Compute the input/output type, that we dub the "parameter type" *) + fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) + val param_type = sumSyntax.list_mk_sum (flatten in_out_tys) + + (* Introduce a variable for the confinuation *) + val fcont = genvar (param_type --> mk_result param_type) + + (* In the function equations, replace all the recursive calls with calls to the continuation. + + When replacing a recursive call, we have to do two things: + - we need to inject the input parameters into the parameter type + Ex.: + - ‘nth tl i’ becomes ‘f (INL (tl, i))’ where ‘f’ is the continuation + - ‘even i’ becomes ‘f (INL i)’ where ‘f’ is the continuation + - we need to wrap the the call to the continuation into a ‘case ... of’ + to extract its output (we need to make sure that the transformation + preserves the type of the expression!) + Ex.: ‘nth tl i’ becomes: + {[ + case f (INL (tl, i)) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + case r of + | INL _ => Fail Failure + | INR x => Return (INR x) + ]} + *) + (* For debugging *) + val replace_rec_calls_rec_call_tm = ref “T” + fun replace_rec_calls (fnames_set : string Redblackset.set) (tm : term) : term = + let + val _ = print_dbg ("replace_rec_calls: original expression:\n" ^ + term_to_string tm ^ "\n\n") + val ntm = + case dest_term tm of + VAR (name, ty) => + (* Check that this is not one of the functions in the group - remark: + we could handle that by introducing lambdas. + *) + if Redblackset.member (fnames_set, name) + then failwith ("mk_body: not well-formed definition: found " ^ name ^ + " in an improper position") + else tm + | CONST _ => tm + | LAMB (x, tm) => + let + (* The variable might shadow one of the functions *) + val fnames_set = Redblackset.delete (fnames_set, (fst o dest_var) x) + (* Update the term in the lambda *) + val tm = replace_rec_calls fnames_set tm + in + (* Reconstruct *) + mk_abs (x, tm) + end + | COMB (_, _) => + let + (* Completely destruct the application, check if this is a recursive call *) + val (app, args) = strip_comb tm + val is_rec_call = Redblackset.member (fnames_set, (fst o dest_var) app) + handle HOL_ERR _ => false + (* Whatever the case, apply the transformation to all the inputs *) + val args = map (replace_rec_calls fnames_set) args + in + (* If this is not a recursive call: apply the transformation to all the + terms. Otherwise, replace. *) + if not is_rec_call then list_mk_comb (replace_rec_calls fnames_set app, args) + else + (* Rec call: replace *) + let + val _ = replace_rec_calls_rec_call_tm := tm + (* First, find the index of the function *) + val fname = (fst o dest_var) app + val fi = Redblackmap.find (fnames_map, fname) + (* Inject the input values into the param type *) + val input = pairSyntax.list_mk_pair args + val input = inject_in_param_sum in_out_tys fi true input + (* Create the recursive call *) + val call = mk_comb (fcont, input) + (* Wrap the call into a ‘case ... of’ to extract the output *) + val call = mk_case_select_result_sum call in_out_tys fi false + in + (* Return *) + call + end + end + val _ = print_dbg ("replace_rec_calls: new expression:\n" ^ term_to_string ntm ^ "\n\n") + in + ntm + end + handle HOL_ERR e => + let + val _ = print_dbg ("replace_rec_calls: failed on:\n" ^ term_to_string tm ^ "\n\n") + in + raise (HOL_ERR e) + end + fun replace_rec_calls_in_eq (eq : term) : term = + let + val (l, r) = dest_eq eq + in + mk_eq (l, replace_rec_calls fnames_set r) + end + val def_tms_with_fcont = map replace_rec_calls_in_eq def_tms + + (* Wrap all the function bodies to inject their result into the param type. + + We collect the function inputs at the same time, because they will be + grouped into a tuple that we will have to deconstruct. + *) + fun inject_body_to_enums (i : int, def_eq : term) : term list * term = + let + val (l, body) = dest_eq def_eq + val (_, args) = strip_comb l + (* We have the deconstruct the result, then, in the ‘Return’ branch, + properly wrap the returned value *) + val body = mk_result_case body (fn x => mk_return (inject_in_param_sum in_out_tys i false x)) + in + (args, body) + end + val def_tms_inject = map inject_body_to_enums (enumerate def_tms_with_fcont) + + (* Currify the body inputs. + + For instance, if the body has inputs: ‘x’, ‘y’; we return the following: + {[ + (‘z’, ‘case z of (x, y) => ... (* body *) ’) + ]} + where ‘z’ is fresh. + + We return: (curried input, body). + + (* Debug *) + val body = “(x:'a, y:'b, z:'c)” + val args = [“x:'a”, “y:'b”, “z:'c”] + currify_body_inputs (args, body) + *) + fun currify_body_inputs (args : term list, body : term) : term * term = + let + fun mk_curry (args : term list) (body : term) : term * term = + case args of + [] => failwith "no inputs" + | [x] => (x, body) + | x1 :: args => + let + val (x2, body) = mk_curry args body + val scrut = genvar (pairSyntax.list_mk_prod (map type_of (x1 :: args))) + val pat = pairSyntax.mk_pair (x1, x2) + val br = body + in + (scrut, TypeBase.mk_case (scrut, [(pat, br)])) + end + in + mk_curry args body + end + val def_tms_currified = map currify_body_inputs def_tms_inject + + (* Group all the functions into a single body, with an outer ‘case .. of’ + which selects the appropriate body depending on the input *) + val param_ty = sumSyntax.list_mk_sum (flatten in_out_tys) + val input = genvar param_ty + fun mk_mut_rec_body_branch (i : int) (patvar : term) : term = + (* Case disjunction on whether the branch is for an input value (in + which case we call the proper body) or an output value (in which + case we return ‘Fail ...’ *) + if i mod 2 = 0 then + let + val fi = i div 2 + val (x, def_tm) = List.nth (def_tms_currified, fi) + (* The variable in the pattern and the variable expected by the + body may not be the same: we introduce a let binding *) + val def_tm = mk_let (mk_abs (x, def_tm), patvar) + in + def_tm + end + else + (* Output value: fail *) + mk_fail_failure param_ty + val mut_rec_body = list_mk_sum_case input (flatten in_out_tys) mk_mut_rec_body_branch + + + (* Abstract away the parameters to produce the final body of the fixed point *) + val mut_rec_body = list_mk_abs ([fcont, input], mut_rec_body) + in + mut_rec_body + end + +(* For explanations about the different steps, see TODO *) +fun DefineDiv (def_qt : term quotation) = + let + (* Parse the definitions. + + Example: + {[ + (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\ + (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1)) + ]} + *) + val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt) + + (* Compute the names and the input/output types of the functions *) + fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) = + let + val app = lhs tm + val name = (fst o dest_var o fst o strip_comb) app + val out_ty = dest_result (type_of app) + val input_tys = pairSyntax.list_mk_prod (map type_of ((snd o strip_comb) app)) + in + (name, (input_tys, out_ty)) + end + val (fnames, in_out_tys) = unzip (map compute_names_in_out_tys def_tms) + + (* Generate the body. + + See the comments at the beginning of the file (lookup "BODY GENERATION"). + *) + val body = mk_body fnames in_out_tys def_tms + + (* Prove that the body satisfies the validity property required by the fixed point *) + val body_is_valid = prove_body_is_valid body + + (* Generate the definitions for the various functions by using the fixed point + and the body. *) + val raw_defs = mk_raw_defs in_out_tys def_tms body_is_valid + + (* Prove the final equations *) + val def_eqs = prove_def_eqs body_is_valid def_tms raw_defs + in + def_eqs + end + +val [even_def, odd_def] = DefineDiv ‘ + (even (i : int) : bool result = + if i = 0 then Return T else odd (i - 1)) /\ + (odd (i : int) : bool result = + if i = 0 then Return F else even (i - 1)) +’ + +val [nth_def] = DefineDiv ‘ + nth (ls : 't list_t) (i : u32) : 't result = + case ls of + | ListCons x tl => + if u32_to_int i = (0:int) + then (Return x) + else + do + i0 <- u32_sub i (int_to_u32 1); + nth tl i0 + od + | ListNil => Fail Failure +’ + +val even_odd_body_def = Define ‘ + even_odd_body + (* The body takes a continuation - required by the fixed-point operator *) + (f : (int + bool + int + bool) -> (int + bool + int + bool) result) + (* The type of the input/output is: + input of even + output of even + input of odd + output of odd + *) + (x : int + bool + int + bool) : + (int + bool + int + bool) result = + (* Case disjunction over the input, in order to figure out which + function from the group is actually called (even , or odd). *) + case x of + | INL i => (* Input of even *) + (* Body of even *) + if i = 0 then Return (INR (INL T)) + else + (* Recursive calls are calls to the continuation f, wrapped + in the proper variant: here we call odd *) + (case f (INR (INR (INL (i - 1)))) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + (* Eliminate the unwanted results *) + case r of + | INL _ => Fail Failure + | INR (INL _) => Fail Failure + | INR (INR (INL _)) => Fail Failure + | INR (INR (INR b)) => (* Extract the output of odd *) + (* Inject into the variant for the output of even *) + Return (INR (INL b)) + ) + | INR (INL _) => (* Output of even *) + (* We must ignore this one *) + Fail Failure + | INR (INR (INL i)) => + (* Body of odd *) + if i = 0 then Return (INR (INR (INR F))) + else + (* Call to even *) + (case f (INL (i - 1)) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + (* Eliminate the unwanted results *) + case r of + | INL _ => Fail Failure + | INR (INL b) => (* Extract the output of even *) + (* Inject into the variant for the output of odd *) + Return (INR (INR (INR b))) + | INR (INR (INL _)) => Fail Failure + | INR (INR (INR _)) => Fail Failure + ) + | INR (INR (INR _)) => (* Output of odd *) + (* We must ignore this one *) + Fail Failure +’ + +Theorem even_odd_body_is_valid_aux: + is_valid_fp_body (SUC (SUC n)) even_odd_body +Proof + prove_body_is_valid_tac (SOME even_odd_body_def) +QED + +Theorem even_odd_body_is_valid: + is_valid_fp_body (SUC (SUC 0)) even_odd_body +Proof + irule even_odd_body_is_valid_aux +QED + +val even_raw_def = Define ‘ + even (i : int) = + case fix even_odd_body (INL i) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + case r of + | INL _ => Fail Failure + | INR (INL b) => Return b + | INR (INR (INL _)) => Fail Failure + | INR (INR (INR _)) => Fail Failure +’ + +val odd_raw_def = Define ‘ + odd (i : int) = + case fix even_odd_body (INR (INR (INL i))) of + | Fail e => Fail e + | Diverge => Diverge + | Return r => + case r of + | INL _ => Fail Failure + | INR (INL b) => Fail Failure + | INR (INR (INL _)) => Fail Failure + | INR (INR (INR b)) => Return b +’ + +Theorem even_def: + ∀i. even (i : int) : bool result = + if i = 0 then Return T else odd (i - 1) +Proof + prove_def_eq_tac even_raw_def [even_raw_def, odd_raw_def] even_odd_body_is_valid even_odd_body_def +QED + +Theorem odd_def: + ∀i. odd (i : int) : bool result = + if i = 0 then Return F else even (i - 1) +Proof + prove_def_eq_tac odd_raw_def [even_raw_def, odd_raw_def] even_odd_body_is_valid even_odd_body_def +QED + +val _ = export_theory () diff --git a/backends/hol4/primitivesBaseTacLib.sml b/backends/hol4/primitivesBaseTacLib.sml index fe87e894..78822abe 100644 --- a/backends/hol4/primitivesBaseTacLib.sml +++ b/backends/hol4/primitivesBaseTacLib.sml @@ -47,11 +47,14 @@ val pat_undisch_tac = Q.PAT_UNDISCH_TAC val equiv_is_imp = prove (“∀x y. ((x ⇒ y) ∧ (y ⇒ x)) ⇒ (x ⇔ y)”, metis_tac []) val equiv_tac = irule equiv_is_imp >> conj_tac +(* Rewrite the goal once, and on the left part of the goal seen as an application *) +fun pure_once_rewrite_left_tac ths = + CONV_TAC (PATH_CONV "l" (PURE_ONCE_REWRITE_CONV ths)) + (* Dependent rewrites *) val dep_pure_once_rewrite_tac = dep_rewrite.DEP_PURE_ONCE_REWRITE_TAC val dep_pure_rewrite_tac = dep_rewrite.DEP_PURE_REWRITE_TAC - (* Add a list of theorems in the assumptions *) fun assume_tacl (thms : thm list) : tactic = let |