open HolKernel boolLib bossLib Parse
open boolTheory arithmeticTheory integerTheory intLib listTheory stringTheory
open primitivesArithTheory primitivesBaseTacLib ilistTheory primitivesTheory
open primitivesLib
open divDefProto2Theory

(* NOTE(review): by HOL convention ‘fooScript.sml’ calls ‘new_theory "foo"’;
   confirm that this theory name matches the name of the script file. *)
val _ = new_theory "divDefProto2TestScript"

(*======================
 * Example 1: nth
 *======================*)

Datatype:
  list_t =
    ListCons 't list_t
  | ListNil
End

(* We use this version of the body to prove that the body is valid *)
val nth_body_def = Define ‘
  nth_body (f : (('t list_t # u32) + 't) -> (('t list_t # u32) + 't) result)
    (x : (('t list_t # u32) + 't)) :
    (('t list_t # u32) + 't) result =
    (* Destruct the input. We need this to call the proper function in case
       of mutually recursive definitions, but also to eliminate arguments
       which correspond to the output value (the input type is the same
       as the output type). *)
    case x of
    | INL x => (
      let (ls, i) = x in
      case ls of
      | ListCons x tl =>
        if u32_to_int i = (0:int)
        then Return (INR x)
        else
          do
          i0 <- u32_sub i (int_to_u32 1);
          r <- f (INL (tl, i0));
          (* Eliminate the invalid outputs. This is not necessary here,
             but it is in the case of non tail call recursive calls. *)
          case r of
          | INL _ => Fail Failure
          | INR i1 => Return (INR i1)
          od
      | ListNil => Fail Failure)
    | INR _ => Fail Failure
’

(* Set to ‘true’ to print debugging information *)
val dbg = ref false
fun print_dbg s = if (!dbg) then print s else ()

(* val (asms, g) = top_goal () *)

(* Tactic to prove that a body is valid: perform one step.

   The goal must be a disjunction whose left disjunct is an equation
   universally quantified over the continuations; the step either performs a
   case disjunction over the scrutinee of the body, or, if the scrutinee is a
   recursive call (an application of the quantified continuation), picks the
   right disjunct and instantiates its existential quantifiers. *)
fun prove_body_is_valid_tac_step (asms, g) =
  let
    (* The goal has the shape:
       {[
         (∀g h. ... g x = ... h x) ∨
         ∃h y. is_valid_fp_body n h ∧ ∀g. ... g x = ... od
       ]}
     *)
    (* Retrieve the scrutinee in the goal (‘x’).
       There are two cases:
       - either the function has the shape:
         {[
           (λ(y,z). ...) x
         ]}
         in which case we need to destruct ‘x’
       - or we have a normal ‘case ... of’
     *)
    val body = (lhs o snd o strip_forall o fst o dest_disj) g
    val scrut =
      let
        val (app, x) = dest_comb body
        val (app, _) = dest_comb app
        val {Name=name, Thy=thy, Ty = _ } = dest_thy_const app
      in
        if thy = "pair" andalso name = "UNCURRY" then x
        else failwith "not a curried argument"
      end
      handle HOL_ERR _ => strip_all_cases_get_scrutinee body
    (* Retrieve the first quantified continuations from the goal (‘g’) *)
    val qc = (hd o fst o strip_forall o fst o dest_disj) g
    (* Check if the scrutinee is a recursive call *)
    val (scrut_app, _) = strip_comb scrut
    val _ = print_dbg ("prove_body_is_valid_step: Scrutinee: " ^
                       term_to_string scrut ^ "\n")
    (* For the recursive calls: *)
    fun step_rec () =
      let
        val _ = print_dbg ("prove_body_is_valid_step: rec call\n")
        (* We need to instantiate the ‘h’ existentially quantified function *)
        (* First, retrieve the body of the function: it is given by the
           ‘Return’ branch *)
        val (_, _, branches) = TypeBase.dest_case body
        (* Find the branch corresponding to the return *)
        val ret_branch =
          List.find (fn (pat, _) =>
                        let
                          val {Name=name, Thy=thy, Ty = _ } =
                            (dest_thy_const o fst o strip_comb) pat
                        in
                          thy = "primitives" andalso name = "Return"
                        end) branches
        val var = (hd o snd o strip_comb o fst o valOf) ret_branch
        val br = (snd o valOf) ret_branch
        (* Abstract away the input variable introduced by the pattern and the
           continuation ‘g’ *)
        val h = list_mk_abs ([qc, var], br)
        val _ = print_dbg ("prove_body_is_valid_step: h: " ^
                           term_to_string h ^ "\n")
        (* Retrieve the input parameter ‘x’ *)
        val input = (snd o dest_comb) scrut
        val _ = print_dbg ("prove_body_is_valid_step: y: " ^
                           term_to_string input ^ "\n")
      in
        ((* Choose the right possibility (this is a recursive call) *)
         disj2_tac >>
         (* Instantiate the quantifiers *)
         qexists ‘^h’ >>
         qexists ‘^input’ >>
         (* Unfold the predicate once *)
         pure_once_rewrite_tac [is_valid_fp_body_def] >>
         (* We have two subgoals:
            - we have to prove that ‘h’ is valid
            - we have to finish the proof of validity for the current body
          *)
         conj_tac >>
         fs [case_result_switch_eq])
      end
  in
    (* If recursive call: special treatment. Otherwise, we do a simple
       disjunction *)
    (if term_eq scrut_app qc then step_rec ()
     else (Cases_on ‘^scrut’ >> fs [case_result_switch_eq])) (asms, g)
  end

(* Tactic to prove that a body is valid.

   ‘body_def’: the definition of the body when the body is a constant (it is
   then expanded), or ‘NONE’ when the body appears directly in the goal. *)
fun prove_body_is_valid_tac (body_def : thm option) : tactic =
  let
    val body_def_thm = case body_def of SOME th => [th] | NONE => []
  in
    pure_once_rewrite_tac [is_valid_fp_body_def] >>
    gen_tac >>
    (* Expand *)
    fs body_def_thm >>
    fs [bind_def, case_result_switch_eq] >>
    (* Explore the body *)
    rpt prove_body_is_valid_tac_step
  end

(* TODO: move *)
val is_valid_fp_body_tm = “is_valid_fp_body”

(* Prove that a body satisfies the validity condition of the fixed point.

   The body is expected to be a lambda whose first bound variable is the
   continuation (it is retrieved with ‘strip_abs’). We prove
   ‘is_valid_fp_body (SUC ... (SUC 0)) body’, with one more ‘SUC’ than the
   number of occurrences of the continuation in the body. *)
fun prove_body_is_valid (body : term) : thm =
  let
    (* Explore the body and count the occurrences of the continuation, so
       that we can properly instantiate the ‘N’ argument of
       ‘is_valid_fp_body’: every recursive call nested in the continuation of
       another recursive call requires one more unfolding of the predicate.

       We add up the counts of both sides of an application instead of
       taking their maximum: in ‘case f x of Return r => ... f y ...’ the two
       calls sit on different sides of the application of the case constant,
       so a maximum would return 1 although the calls are nested. Counting
       all the occurrences gives an overapproximation of the nesting depth.

       We first retrieve the name of the continuation parameter.
       Rem.: we generated fresh names so that, for instance, the continuation
       name doesn't collide with other names. Because of this, we don't need
       to look for collisions when exploring the body (and in the worst case,
       we would count an overapproximation of the number of recursive calls,
       which is perfectly valid). *)
    val fcont = (hd o fst o strip_abs) body
    val fcont_name = (fst o dest_var) fcont
    fun count_body_rec_calls (tm : term) : int =
      case dest_term tm of
        VAR (name, _) => if name = fcont_name then 1 else 0
      | CONST _ => 0
      | COMB (x, y) => count_body_rec_calls x + count_body_rec_calls y
      | LAMB (_, x) => count_body_rec_calls x
    val num_rec_calls = count_body_rec_calls body
    (* Generate the term ‘SUC (SUC ... (SUC n))’ where ‘n’ is a fresh
       variable.

       Remark: we first prove ‘is_valid_fp_body (SUC ... n) body’ then
       substitute ‘n’ with ‘0’ to prevent the quantity from being rewritten
       to a bit representation, which would prevent unfolding of the
       ‘is_valid_fp_body’.

       We use ‘numSyntax’ directly because the ‘num’-related helpers of this
       file are only defined further down.
     *)
    val nvar = genvar numSyntax.num
    (* Rem.: we stack num_rec_calls + 1 occurrences of ‘SUC’ (and the + 1 is
       important) *)
    fun mk_n i =
      if i = 0 then numSyntax.mk_suc nvar
      else numSyntax.mk_suc (mk_n (i - 1))
    val n_tm = mk_n num_rec_calls
    (* Generate the lemma statement *)
    val is_valid_tm = list_mk_icomb (is_valid_fp_body_tm, [n_tm, body])
    val is_valid_thm = prove (is_valid_tm, prove_body_is_valid_tac NONE)
    (* Replace ‘nvar’ with ‘0’ *)
    val is_valid_thm = INST [nvar |-> numSyntax.zero_tm] is_valid_thm
  in
    is_valid_thm
  end

(* val (asms, g) = top_goal () *)

(* We first prove the theorem with ‘SUC (SUC n)’ where ‘n’ is a variable to
   prevent this quantity from being rewritten to 2 *)
Theorem nth_body_is_valid_aux:
  is_valid_fp_body (SUC (SUC n)) nth_body
Proof
  prove_body_is_valid_tac (SOME nth_body_def)
QED

Theorem nth_body_is_valid:
  is_valid_fp_body (SUC (SUC 0)) nth_body
Proof
  irule nth_body_is_valid_aux
QED

val nth_raw_def = Define ‘
  nth (ls : 't list_t) (i : u32) =
    case fix nth_body (INL (ls, i)) of
    | Fail e => Fail e
    | Diverge => Diverge
    | Return r =>
      case r of
      | INL _ => Fail Failure
      | INR x => Return x
’

val fix_tm = “fix”

(* Generate the raw definitions by using the grouped definition body and the
   fixed point operator.

   ‘in_out_tys’ gives, for every function of the group, the pair (input
   type, output type); ‘def_tms’ are the original equations, in the same
   order; ‘body_is_valid’ is the validity theorem whose last argument is the
   body. Returns one definition (produced by ‘Define’) per function.

   NOTE(review): this function uses ‘inject_in_param_sum’,
   ‘mk_case_select_result_sum’ and ‘enumerate’, which are defined further
   down in this file. SML requires names to be bound before use, so unless an
   opened library already provides them, this definition must be moved below
   those helpers for the script to load. *)
fun mk_raw_defs (in_out_tys : (hol_type * hol_type) list)
  (def_tms : term list) (body_is_valid : thm) : thm list =
  let
    (* Retrieve the body *)
    val body = (List.last o snd o strip_comb o concl) body_is_valid
    (* Create the term ‘fix body’ *)
    val fixed_body = mk_icomb (fix_tm, body)
    (* For every function in the group, generate the equation that we will
       use as definition. In particular:
       - add the properly injected input ‘x’ to ‘fix body’ (ex.: for
         ‘nth ls i’ we add the input ‘INL (ls, i)’)
       - wrap ‘fix body x’ into a case disjunction to extract the relevant
         output

       For instance, in the case of ‘nth ls i’:
       {[
         nth (ls : 't list_t) (i : u32) =
           case fix nth_body (INL (ls, i)) of
           | Fail e => Fail e
           | Diverge => Diverge
           | Return r =>
             case r of
             | INL _ => Fail Failure
             | INR x => Return x
       ]}
     *)
    fun mk_def_eq (i : int, def_tm : term) : term =
      let
        (* Retrieve the lhs of the original definition equation, and in
           particular the inputs *)
        val def_lhs = lhs def_tm
        val args = (snd o strip_comb) def_lhs
        (* Inject the inputs into the param type *)
        val input = pairSyntax.list_mk_pair args
        val input = inject_in_param_sum in_out_tys i true input
        (* Compose *)
        val def_rhs = mk_comb (fixed_body, input)
        (* Wrap in the case disjunction *)
        val def_rhs = mk_case_select_result_sum def_rhs in_out_tys i false
        (* Create the equation *)
        val def_eq_tm = mk_eq (def_lhs, def_rhs)
      in
        def_eq_tm
      end
    val raw_def_tms = map mk_def_eq (enumerate def_tms)
    (* Generate the definitions *)
    val raw_defs = map (fn tm => Define ‘^tm’) raw_def_tms
  in
    raw_defs
  end

(* Tactic which makes progress in a proof by making a case disjunction (we use
   this to explore all the paths in a function body). The goal must be an
   equation; the scrutinee is looked up in its left-hand side. *)
fun case_progress (asms, g) =
  let
    val scrut = (strip_all_cases_get_scrutinee o lhs) g
  in
    Cases_on ‘^scrut’ (asms, g)
  end

(* Prove the final equation, that we will use as definition.
*)
(* Parameters:
   - ‘current_raw_def’: the raw definition (in terms of ‘fix’) of the
     function whose equation we prove
   - ‘all_raw_defs’: the raw definitions of all the functions of the group
   - ‘is_valid’: the validity theorem of the body, used to instantiate
     ‘fix_fixed_eq’
   - ‘body_def’: the definition of the body if the body is a constant (it is
     then expanded), ‘NONE’ otherwise
 *)
fun prove_def_eq_tac (current_raw_def : thm) (all_raw_defs : thm list)
  (is_valid : thm) (body_def : thm option) : tactic =
  let
    val body_def_thm = case body_def of SOME th => [th] | NONE => []
  in
    rpt gen_tac >>
    (* Expand the definition *)
    pure_once_rewrite_tac [current_raw_def] >>
    (* Use the fixed-point equality *)
    pure_once_rewrite_left_tac [HO_MATCH_MP fix_fixed_eq is_valid] >>
    (* Expand the body definition *)
    pure_rewrite_tac body_def_thm >>
    (* Expand all the definitions from the group *)
    pure_rewrite_tac all_raw_defs >>
    (* Explore all the paths - maybe we can be smarter, but this is fast and
       really easy *)
    fs [bind_def] >>
    rpt (case_progress >> fs [])
  end

(* Prove the final equations that we will give to the user as definitions.

   ‘def_tms’ are the original equations (in which the functions are
   variables) and ‘raw_defs’ the corresponding raw definitions, in the same
   order. *)
fun prove_def_eqs (body_is_valid : thm) (def_tms : term list)
  (raw_defs : thm list) : thm list =
  let
    val defs_tgt_raw = zip def_tms raw_defs
    (* Substitute the function variables with the constants introduced in the
       raw definitions *)
    fun compute_fsubst (def_tm, raw_def) : {redex: term, residue: term} =
      let
        val (fvar, _) = (strip_comb o lhs) def_tm
        val fconst =
          (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def
      in
        (fvar |-> fconst)
      end
    val fsubst = map compute_fsubst defs_tgt_raw
    val defs_tgt_raw = map (fn (x, y) => (subst fsubst x, y)) defs_tgt_raw
    fun prove_def_eq (def_tm, raw_def) : thm =
      let
        (* Quantify the parameters *)
        val (_, params) = (strip_comb o lhs) def_tm
        val def_eq_tm = list_mk_forall (params, def_tm)
        (* Prove *)
        val def_eq =
          prove (def_eq_tm,
                 prove_def_eq_tac raw_def raw_defs body_is_valid NONE)
      in
        def_eq
      end
    val def_eqs = map prove_def_eq defs_tgt_raw
  in
    def_eqs
  end

(* ‘prove_def_eq_tac’ expects the body definition as a ‘thm option’: the body
   of ‘nth’ is the constant ‘nth_body’, so we pass its definition. *)
Theorem nth_def:
  ∀ls i. nth (ls : 't list_t) (i : u32) : 't result =
    case ls of
    | ListCons x tl =>
      if u32_to_int i = (0:int)
      then (Return x)
      else
        do
        i0 <- u32_sub i (int_to_u32 1);
        nth tl i0
        od
    | ListNil => Fail Failure
Proof
  prove_def_eq_tac nth_raw_def [nth_raw_def] nth_body_is_valid
    (SOME nth_body_def)
QED

(*======================
 * Example 2: even, odd
 *======================*)

(* NOTE(review): ‘def_qt’ is not referenced in the visible part of this
   file. *)
val def_qt = ‘
  (even (i : int) : bool result =
    if i = 0 then Return T else odd (i - 1)) /\
  (odd (i : int) : bool result =
    if i = 0 then Return F else even (i - 1))
’

(* Types and terms used to build ‘result’ expressions *)
val result_ty = “:'a result”
val error_ty = “:error”
val alpha_ty = “:'a”
val num_ty = “:num”
val return_tm = “Return : 'a -> 'a result”
val fail_tm = “Fail : error -> 'a result”
val fail_failure_tm = “Fail Failure : 'a result”
val diverge_tm = “Diverge : 'a result”
val zero_num_tm = “0:num”
val suc_tm = “SUC”

(* Build the type ‘:ty result’ *)
fun mk_result (ty : hol_type) : hol_type =
  Type.type_subst [ alpha_ty |-> ty ] result_ty

(* Retrieve ‘ty’ from ‘:ty result’; fails if the type is not a
   ‘primitives$result’ type *)
fun dest_result (ty : hol_type) : hol_type =
  let
    val {Args=out_ty, Thy=thy, Tyop=tyop} = dest_thy_type ty
  in
    if thy = "primitives" andalso tyop = "result" then hd out_ty
    else failwith "dest_result: not a result"
  end

(* Build ‘Return x’ *)
fun mk_return (x : term) : term = mk_icomb (return_tm, x)

(* Build ‘Fail e : ty result’ *)
fun mk_fail (ty : hol_type) (e : term) : term =
  mk_comb (inst [ alpha_ty |-> ty ] fail_tm, e)

(* Build ‘Fail Failure : ty result’ *)
fun mk_fail_failure (ty : hol_type) : term =
  inst [ alpha_ty |-> ty ] fail_failure_tm

(* Build ‘Diverge : ty result’ *)
fun mk_diverge (ty : hol_type) : term =
  inst [ alpha_ty |-> ty ] diverge_tm

(* Build ‘SUC n’ *)
fun mk_suc (n : term) = mk_comb (suc_tm, n)

(* *)

(* **BODY GENERATION**:
   ====================

   When generating a recursive definition, we apply a fixed-point operator to
   a function body. In case we define a group of mutually recursive
   definitions, we generate *one* single body for the whole group of
   definitions. It works as follows.

   The input of the body is an enumeration: we start by branching over this
   input, and every branch corresponds to one function in the mutually
   recursive group. Also, the inputs must be grouped into tuples.
   Whenever we make a recursive call, we wrap the input parameters into the
   proper variant, so as to call the proper function.

   Moreover, the input of the body must have the same type as its output: we
   also store the outputs of the functions in some variants of the
   enumeration.

   In order to make this work, we need to shape the body so that:
   - input values/output values are properly injected into the enumeration
   - whenever we get an output value (which is an enumeration), we extract
     the value from the proper variant of the enumeration

   We encode the enumeration with a nested sum type, whose constructors are
   ‘INL’ and ‘INR’.

   Example:
   ========
   We consider the following group of mutually recursive definitions:
 *)
(* Parsed form of the example equations.
   NOTE(review): not referenced in the visible part of this file. *)
val even_odd_qt = Defn.parse_quote ‘
  (even (i : int) : bool result =
    if i = 0 then Return T else odd (i - 1)) /\
  (odd (i : int) : bool result =
    if i = 0 then Return F else even (i - 1))
’

(* From those equations, we generate the following body (written by hand
   here; the same shape is produced automatically by ‘mk_body’): *)
val even_odd_body_def = Define ‘
  even_odd_body
    (* The body takes a continuation - required by the fixed-point operator *)
    (f : (int + bool + int + bool) -> (int + bool + int + bool) result)
    (* The type of the input is:
       input of even + output of even + input of odd + output of odd *)
    (x : int + bool + int + bool) :
    (* The output type is the same as the input type - this constraint comes
       from limitations in the way we can define the fixed-point operator
       inside the HOL logic *)
    (int + bool + int + bool) result =
    (* Case disjunction over the input, in order to figure out which function
       from the group is actually called (even, or odd). *)
    case x of
    | INL i => (* Input of even *)
      (* Body of even *)
      if i = 0 then Return (INR (INL T))
      else
        (* Recursive calls are calls to the continuation f, wrapped in the
           proper variant: here we call odd *)
        (case f (INR (INR (INL (i - 1)))) of
         | Fail e => Fail e
         | Diverge => Diverge
         | Return r =>
           (* Extract the proper value from the enumeration: here, the call
              is tail-call so this is not really necessary, but we might need
              to retrieve the output of the call to odd, which is a boolean,
              and do something more complex with it. *)
           case r of
           | INL _ => Fail Failure
           | INR (INL _) => Fail Failure
           | INR (INR (INL _)) => Fail Failure
           | INR (INR (INR b)) => (* Extract the output of odd *)
             (* Return: inject into the variant for the output of even *)
             Return (INR (INL b))
        )
    | INR (INL _) => (* Output of even *)
      (* We must ignore this one *)
      Fail Failure
    | INR (INR (INL i)) =>
      (* Body of odd *)
      if i = 0 then Return (INR (INR (INR F)))
      else
        (* Call to even *)
        (case f (INL (i - 1)) of
         | Fail e => Fail e
         | Diverge => Diverge
         | Return r =>
           (* Extract the proper value from the enumeration *)
           case r of
           | INL _ => Fail Failure
           | INR (INL b) => (* Extract the output of even *)
             (* Return: inject into the variant for the output of odd *)
             Return (INR (INR (INR b)))
           | INR (INR (INL _)) => Fail Failure
           | INR (INR (INR _)) => Fail Failure
        )
    | INR (INR (INR _)) => (* Output of odd *)
      (* We must ignore this one *)
      Fail Failure
’

(* Small helper to generate wrappers of the shape: ‘INL x’, ‘INR (INL x)’,
   etc.

   Note that we should have:
   ‘length before_tys + 1 + length after_tys >= 2’

   Ex.:
   ====
   The enumeration has type: “: 'a + 'b + 'c + 'd”.
   We want to generate the variant which injects “x:'c” into this
   enumeration.
We need to split the list of types into: {[ before_tys = [“:'a”, “'b”] tm = “x: 'c” after_tys = [“:'d”] ]} The function will generate: {[ INR (INR (INL x) : 'a + 'b + 'c + 'd ]} (* Debug *) val before_tys = [“:'a”, “:'b”, “:'c”] val tm = “x:'d” val after_tys = [“:'e”, “:'f”] val before_tys = [“:'a”, “:'b”, “:'c”] val tm = “x:'d” val after_tys = [] mk_inl_inr_wrapper before_tys tm after_tys *) fun list_mk_inl_inr (before_tys : hol_type list) (tm : term) (after_tys : hol_type list) : term = let val (before_tys, pat) = if after_tys = [] then let val just_before_ty = List.last before_tys val before_tys = List.take (before_tys, List.length before_tys - 1) val pat = sumSyntax.mk_inr (tm, just_before_ty) in (before_tys, pat) end else (before_tys, sumSyntax.mk_inl (tm, sumSyntax.list_mk_sum after_tys)) val pat = foldr (fn (ty, pat) => sumSyntax.mk_inr (pat, ty)) pat before_tys in pat end (* This function wraps a term into the proper variant of the input/output sum. Ex.: ==== For the input of the first function, we generate: ‘INL x’ For the output of the first function, we generate: ‘INR (INL x)’ For the input of the 2nd function, we generate: ‘INR (INR (INL x))’ etc. If ‘is_input’ is true: we are wrapping an input. Otherwise we are wrapping an output. (* Debug *) val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”), (“:'e”, “:'f”)] val j = 1 val tm = “x:'c” val tm = “y:'d” val is_input = true *) fun inject_in_param_sum (tys : (hol_type * hol_type) list) (j : int) (is_input : bool) (tm : term) : term = let fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) val before_tys = flatten (List.take (tys, j)) val (input_ty, output_ty) = List.nth (tys, j) val after_tys = flatten (List.drop (tys, j + 1)) val (before_tys, after_tys) = if is_input then (before_tys, output_ty :: after_tys) else (before_tys @ [input_ty], after_tys) in list_mk_inl_inr before_tys tm after_tys end (* Remark: the order of the branches when creating matches is important. 
For instance, in the case of ‘result’ it must be: ‘Return’, ‘Fail’, ‘Diverge’. For the purpose of stability and maintainability, we introduce this small helper which reorders the cases in a pattern before actually creating the case expression. *) fun unordered_mk_case (scrut: term, pats: (term * term) list) : term = let (* Retrieve the constructors *) val cl = TypeBase.constructors_of (type_of scrut) (* Retrieve the names of the constructors *) val names = map (fst o dest_const) cl (* Use those to reorder the patterns *) fun is_pat (name : string) (pat, _) = let val app = (fst o strip_comb) pat val app_name = (fst o dest_const) app in app_name = name end val pats = map (fn name => valOf (List.find (is_pat name) pats)) names in (* Create the case *) TypeBase.mk_case (scrut, pats) end (* Wrap a term of type “:'a result” into a ‘case of’ which matches over the result. Ex.: ==== {[ f x ~~> case f x of | Fail e => Fail e | Diverge => Diverge | Return y => ... (* The branch content is generated by the continuation *) ]} ‘gen_ret_branch’ is a *continuation* which generates the content of the ‘Return’ branch (i.e., the content of the ‘...’ in the example above). It receives as input the value contained by the ‘Return’ (i.e., the variable ‘y’ in the example above). Remark.: the type of the term generated by ‘gen_ret_branch’ must have the type ‘result’, but it can change the content of the result (i.e., if ‘scrut’ has type ‘:'a result’, we can change the type of the wrapped expression to ‘:'b result’). 
(* Debug *) val scrut = “x: int result” fun gen_ret_branch x = mk_return x val scrut = “x: int result” fun gen_ret_branch _ = “Return T” mk_result_case scrut gen_ret_branch *) fun mk_result_case (scrut : term) (gen_ret_branch : term -> term) : term = let val scrut_ty = dest_result (type_of scrut) (* Return branch *) val ret_var = genvar scrut_ty val ret_pat = mk_return ret_var val ret_br = gen_ret_branch ret_var val ret_ty = dest_result (type_of ret_br) (* Failure branch *) val fail_var = genvar error_ty val fail_pat = mk_fail scrut_ty fail_var val fail_br = mk_fail ret_ty fail_var (* Diverge branch *) val div_pat = mk_diverge scrut_ty val div_br = mk_diverge ret_ty in unordered_mk_case (scrut, [(ret_pat, ret_br), (fail_pat, fail_br), (div_pat, div_br)]) end (* Generate a ‘case ... of’ over a sum type. Ex.: ==== If the scrutinee is: “x : 'a + 'b + 'c” (i.e., the tys list is: [“:'a”, “:b”, “:c”]), we generate: {[ case x of | INL y0 => ... (* Branch of index 0 *) | INR (INL y1) => ... (* Branch of index 1 *) | INR (INR (INL y2)) => ... (* Branch of index 2 *) | INR (INR (INR y3)) => ... (* Branch of index 3 *) ]} The content of the branches is generated by the ‘gen_branch’ continuation, which receives as input the index of the branch as well as the variable introduced by the pattern (in the example above: ‘y0’ for the branch 0, ‘y1’ for the branch 1, etc.) 
(* Debug *) val tys = [“:'a”, “:'b”] val scrut = mk_var ("x", sumSyntax.list_mk_sum tys) fun gen_branch i (x : term) = “F” val tys = [“:'a”, “:'b”, “:'c”, “:'d”] val scrut = mk_var ("x", sumSyntax.list_mk_sum tys) fun gen_branch i (x : term) = if type_of x = “:'c” then mk_return x else mk_fail_failure “:'c” list_mk_sum_case scrut tys gen_branch *) (* For debugging *) val list_mk_sum_case_case = ref (“T”, [] : (term * term) list) (* val (scrut, [(pat1, br1), (pat2, br2)]) = !list_mk_sum_case_case *) fun list_mk_sum_case (scrut : term) (tys : hol_type list) (gen_branch : int -> term -> term) : term = let (* Create the cases. Note that without sugar, the match actually looks like this: {[ case x of | INL y0 => ... (* Branch of index 0 *) | INR x1 case x1 of | INL y1 => ... (* Branch of index 1 *) | INR x2 => case x2 of | INL y2 => ... (* Branch of index 2 *) | INR y3 => ... (* Branch of index 3 *) ]} *) fun create_case (j : int) (scrut : term) (tys : hol_type list) : term = let val _ = print_dbg ("list_mk_sum_case: " ^ String.concatWith ", " (map type_to_string tys) ^ "\n") in case tys of [] => failwith "tys is too short" | [ ty ] => (* Last element: no match to perform *) gen_branch j scrut | ty1 :: tys => (* Not last: we create a pattern: {[ case scrut of | INL pat_var1 => ... (* Branch of index i *) | INR pat_var2 => ... 
(* Generate this term recursively *) ]} *) let (* INL branch *) val after_ty = sumSyntax.list_mk_sum tys val pat_var1 = genvar ty1 val pat1 = sumSyntax.mk_inl (pat_var1, after_ty) val br1 = gen_branch j pat_var1 (* INR branch *) val pat_var2 = genvar after_ty val pat2 = sumSyntax.mk_inr (pat_var2, ty1) val br2 = create_case (j+1) pat_var2 tys val _ = print_dbg ("list_mk_sum_case: assembling:\n" ^ term_to_string scrut ^ ",\n" ^ "[(" ^ term_to_string pat1 ^ ",\n " ^ term_to_string br1 ^ "),\n\n" ^ " (" ^ term_to_string pat2 ^ ",\n " ^ term_to_string br2 ^ ")]\n\n") val case_elems = (scrut, [(pat1, br1), (pat2, br2)]) val _ = list_mk_sum_case_case := case_elems in (* Put everything together *) TypeBase.mk_case case_elems end end in create_case 0 scrut tys end (* Generate a ‘case ... of’ to select the input/output of the ith variant of the param enumeration. Ex.: ==== There are two functions in the group, and we select the input of the function of index 1: {[ case x of | INL _ => Fail Failure (* Input of function of index 0 *) | INR (INL _) => Fail Failure (* Output of function of index 0 *) | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *) | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *) ]} (* Debug *) val tys = [(“:'a”, “:'b”)] val scrut = “x : 'a + 'b” val fi = 0 val is_input = true val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”)] val scrut = “x : 'a + 'b + 'c + 'd” val fi = 1 val is_input = false val scrut = mk_var ("x", sumSyntax.list_mk_sum (flatten tys)) list_mk_case_select scrut tys fi is_input *) fun list_mk_case_sum_select (scrut : term) (tys : (hol_type * hol_type) list) (fi : int) (is_input : bool) : term = let (* The index of the element in the enumeration that we will select *) val i = 2 * fi + (if is_input then 0 else 1) (* Flatten the types and numerotate them *) fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) val tys = flatten tys (* Get the return type *) val ret_ty = List.nth 
(tys, i) (* The continuation which will generate the content of the branches *) fun gen_branch j var = if j = i then mk_return var else mk_fail_failure ret_ty in (* Generate the ‘case ... of’ *) list_mk_sum_case scrut tys gen_branch end (* Generate a ‘case ... of’ to select the input/output of the ith variant of the param enumeration. Ex.: ==== There are two functions in the group, and we select the input of the function of index 1: {[ case x of | Fail e => Fail e | Diverge => Diverge | Return r => case r of | INL _ => Fail Failure (* Input of function of index 0 *) | INR (INL _) => Fail Failure (* Output of function of index 0 *) | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *) | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *) ]} *) fun mk_case_select_result_sum (scrut : term) (tys : (hol_type * hol_type) list) (fi : int) (is_input : bool) : term = (* We match over the result, then over the enumeration *) mk_result_case scrut (fn x => list_mk_case_sum_select x tys fi is_input) (* val scrut = call val tys = in_out_tys val is_input = false val call = mk_case_select_result_sum call in_out_tys fi false *) (* TODO: move *) fun enumerate (ls : 'a list) : (int * 'a) list = zip (List.tabulate (List.length ls, fn i => i)) ls (* Generate a body for the fixed-point operator from a quoted group of mutually recursive definitions. See TODO for detailed explanations: from the quoted equations for ‘nth’ (or for [‘even’, ‘odd’]) we generate the body ‘nth_body’ (or ‘even_odd_body’, respectively). 
*)
(* Parameters:
   - ‘fnames’: the names of the functions of the group
   - ‘in_out_tys’: for every function, the pair (input type, output type),
     where the input type is the tuple of the argument types
   - ‘def_tms’: the equations defining the functions, in the same order

   The result is a term of the shape ‘λf x. ...’ where ‘f’ is the
   continuation. *)
fun mk_body (fnames : string list) (in_out_tys : (hol_type * hol_type) list)
  (def_tms : term list) : term =
  let
    val fnames_set = Redblackset.fromList String.compare fnames
    (* Compute a map from function name to function index *)
    val fnames_map =
      Redblackmap.fromList String.compare
        (map (fn (x, y) => (y, x)) (enumerate fnames))
    (* Compute the input/output type, that we dub the "parameter type" *)
    fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls)
    val param_type = sumSyntax.list_mk_sum (flatten in_out_tys)
    (* Introduce a variable for the continuation *)
    val fcont = genvar (param_type --> mk_result param_type)

    (* In the function equations, replace all the recursive calls with calls
       to the continuation.

       When replacing a recursive call, we have to do two things:
       - we need to inject the input parameters into the parameter type
         Ex.:
         - ‘nth tl i’ becomes ‘f (INL (tl, i))’ where ‘f’ is the continuation
         - ‘even i’ becomes ‘f (INL i)’ where ‘f’ is the continuation
       - we need to wrap the call to the continuation into a ‘case ... of’ to
         extract its output (we need to make sure that the transformation
         preserves the type of the expression!)
         Ex.: ‘nth tl i’ becomes:
         {[
           case f (INL (tl, i)) of
           | Fail e => Fail e
           | Diverge => Diverge
           | Return r =>
             case r of
             | INL _ => Fail Failure
             | INR x => Return x
         ]}
     *)
    (* For debugging *)
    val replace_rec_calls_rec_call_tm = ref “T”
    fun replace_rec_calls (fnames_set : string Redblackset.set) (tm : term)
      : term =
      let
        val _ = print_dbg ("replace_rec_calls: original expression:\n" ^
                           term_to_string tm ^ "\n\n")
        val ntm =
          case dest_term tm of
            VAR (name, _) =>
            (* Check that this is not one of the functions in the group -
               remark: we could handle that by introducing lambdas. *)
            if Redblackset.member (fnames_set, name)
            then failwith ("mk_body: not well-formed definition: found " ^
                           name ^ " in an improper position")
            else tm
          | CONST _ => tm
          | LAMB (x, body) =>
            let
              (* The variable might shadow one of the functions: remove its
                 name from the set. We check the membership first because
                 ‘Redblackset.delete’ raises ‘NotFound’ when the element is
                 absent - which is the common case (e.g., ‘i0’ bound by the
                 ‘do’ block of ‘nth’) - and this exception would escape the
                 ‘HOL_ERR’ handler below. *)
              val bvar_name = (fst o dest_var) x
              val fnames_set =
                if Redblackset.member (fnames_set, bvar_name)
                then Redblackset.delete (fnames_set, bvar_name)
                else fnames_set
              (* Update the term in the lambda *)
              val body = replace_rec_calls fnames_set body
            in
              (* Reconstruct *)
              mk_abs (x, body)
            end
          | COMB (_, _) =>
            let
              (* Completely destruct the application, check if this is a
                 recursive call *)
              val (app, args) = strip_comb tm
              val is_rec_call =
                Redblackset.member (fnames_set, (fst o dest_var) app)
                handle HOL_ERR _ => false
              (* Whatever the case, apply the transformation to all the
                 inputs *)
              val args = map (replace_rec_calls fnames_set) args
            in
              (* If this is not a recursive call: apply the transformation to
                 all the terms. Otherwise, replace. *)
              if not is_rec_call
              then list_mk_comb (replace_rec_calls fnames_set app, args)
              else
                (* Rec call: replace *)
                let
                  val _ = replace_rec_calls_rec_call_tm := tm
                  (* First, find the index of the function *)
                  val fname = (fst o dest_var) app
                  val fi = Redblackmap.find (fnames_map, fname)
                  (* Inject the input values into the param type *)
                  val input = pairSyntax.list_mk_pair args
                  val input = inject_in_param_sum in_out_tys fi true input
                  (* Create the recursive call *)
                  val call = mk_comb (fcont, input)
                  (* Wrap the call into a ‘case ... of’ to extract the
                     output *)
                  val call =
                    mk_case_select_result_sum call in_out_tys fi false
                in
                  (* Return *)
                  call
                end
            end
        val _ = print_dbg ("replace_rec_calls: new expression:\n" ^
                           term_to_string ntm ^ "\n\n")
      in
        ntm
      end
      handle HOL_ERR e =>
        let
          val _ = print_dbg ("replace_rec_calls: failed on:\n" ^
                             term_to_string tm ^ "\n\n")
        in
          raise (HOL_ERR e)
        end

    fun replace_rec_calls_in_eq (eq : term) : term =
      let
        val (l, r) = dest_eq eq
      in
        mk_eq (l, replace_rec_calls fnames_set r)
      end

    val def_tms_with_fcont = map replace_rec_calls_in_eq def_tms

    (* Wrap all the function bodies to inject their result into the param
       type.

       We collect the function inputs at the same time, because they will be
       grouped into a tuple that we will have to deconstruct. *)
    fun inject_body_to_enums (i : int, def_eq : term) : term list * term =
      let
        val (l, body) = dest_eq def_eq
        val (_, args) = strip_comb l
        (* We have to deconstruct the result, then, in the ‘Return’ branch,
           properly wrap the returned value *)
        val body =
          mk_result_case body
            (fn x => mk_return (inject_in_param_sum in_out_tys i false x))
      in
        (args, body)
      end

    val def_tms_inject =
      map inject_body_to_enums (enumerate def_tms_with_fcont)

    (* Currify the body inputs.

       For instance, if the body has inputs: ‘x’, ‘y’; we return the
       following:
       {[
         (‘z’, ‘case z of (x, y) => ... (* body *) ’)
       ]}
       where ‘z’ is fresh.

       We return: (curried input, body).

       (* Debug *)
       val body = “(x:'a, y:'b, z:'c)”
       val args = [“x:'a”, “y:'b”, “z:'c”]
       currify_body_inputs (args, body)
     *)
    fun currify_body_inputs (args : term list, body : term) : term * term =
      let
        fun mk_curry (args : term list) (body : term) : term * term =
          case args of
            [] => failwith "no inputs"
          | [x] => (x, body)
          | x1 :: args =>
            let
              val (x2, body) = mk_curry args body
              val scrut =
                genvar (pairSyntax.list_mk_prod (map type_of (x1 :: args)))
              val pat = pairSyntax.mk_pair (x1, x2)
              val br = body
            in
              (scrut, TypeBase.mk_case (scrut, [(pat, br)]))
            end
      in
        mk_curry args body
      end

    val def_tms_currified = map currify_body_inputs def_tms_inject

    (* Group all the functions into a single body, with an outer ‘case .. of’
       which selects the appropriate body depending on the input *)
    val param_ty = sumSyntax.list_mk_sum (flatten in_out_tys)
    val input = genvar param_ty

    fun mk_mut_rec_body_branch (i : int) (patvar : term) : term =
      (* Case disjunction on whether the branch is for an input value (in
         which case we call the proper body) or an output value (in which
         case we return ‘Fail ...’ *)
      if i mod 2 = 0 then
        let
          val fi = i div 2
          val (x, def_tm) = List.nth (def_tms_currified, fi)
          (* The variable in the pattern and the variable expected by the
             body may not be the same: we introduce a let binding *)
          val def_tm = mk_let (mk_abs (x, def_tm), patvar)
        in
          def_tm
        end
      else
        (* Output value: fail *)
        mk_fail_failure param_ty

    val mut_rec_body =
      list_mk_sum_case input (flatten in_out_tys) mk_mut_rec_body_branch

    (* Abstract away the parameters to produce the final body of the fixed
       point *)
    val mut_rec_body = list_mk_abs ([fcont, input], mut_rec_body)
  in
    mut_rec_body
  end

(* For explanations about the different steps, see TODO *)
(* Define a group of (possibly mutually recursive, possibly divergent)
   functions from a quotation of their equations, and return the proven
   equations, one per function. *)
fun DefineDiv (def_qt : term quotation) =
  let
    (* Parse the definitions.

       Example:
       {[
         (even (i : int) : bool result = if i = 0 then Return T else odd (i - 1)) /\
         (odd (i : int) : bool result = if i = 0 then Return F else even (i - 1))
       ]}

       NOTE(review): the ‘rev’ suggests that ‘Defn.parse_quote’ returns the
       equations as a list in reverse order - confirm.
     *)
    val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
    (* Compute the names and the input/output types of the functions *)
    fun compute_names_in_out_tys (tm : term)
      : string * (hol_type * hol_type) =
      let
        val app = lhs tm
        val name = (fst o dest_var o fst o strip_comb) app
        val out_ty = dest_result (type_of app)
        val input_tys =
          pairSyntax.list_mk_prod (map type_of ((snd o strip_comb) app))
      in
        (name, (input_tys, out_ty))
      end
    val (fnames, in_out_tys) = unzip (map compute_names_in_out_tys def_tms)
    (* Generate the body. See the comments at the beginning of the file
       (lookup "BODY GENERATION"). *)
    val body = mk_body fnames in_out_tys def_tms
    (* Prove that the body satisfies the validity property required by the
       fixed point *)
    val body_is_valid = prove_body_is_valid body
    (* Generate the definitions for the various functions by using the fixed
       point and the body. *)
    val raw_defs = mk_raw_defs in_out_tys def_tms body_is_valid
    (* Prove the final equations *)
    val def_eqs = prove_def_eqs body_is_valid def_tms raw_defs
  in
    def_eqs
  end

val [even_def, odd_def] = DefineDiv ‘
  (even (i : int) : bool result =
    if i = 0 then Return T else odd (i - 1)) /\
  (odd (i : int) : bool result =
    if i = 0 then Return F else even (i - 1))
’

val [nth_def] = DefineDiv ‘
  nth (ls : 't list_t) (i : u32) : 't result =
    case ls of
    | ListCons x tl =>
      if u32_to_int i = (0:int)
      then (Return x)
      else
        do
        i0 <- u32_sub i (int_to_u32 1);
        nth tl i0
        od
    | ListNil => Fail Failure
’

val even_odd_body_def = Define ‘ even_odd_body (* The body takes a continuation - required by the fixed-point operator *) (f : (int + bool + int + bool) -> (int + bool + int + bool) result) (* The type of the input/output is: input of even + output of even + input of odd + output of odd *) (x : int + bool + int + bool) : (int + bool + int + bool) result = (* Case disjunction over the input, in order to figure out which function from the group is actually called (even , or odd).
*) case x of | INL i => (* Input of even *) (* Body of even *) if i = 0 then Return (INR (INL T)) else (* Recursive calls are calls to the continuation f, wrapped in the proper variant: here we call odd *) (case f (INR (INR (INL (i - 1)))) of | Fail e => Fail e | Diverge => Diverge | Return r => (* Eliminate the unwanted results *) case r of | INL _ => Fail Failure | INR (INL _) => Fail Failure | INR (INR (INL _)) => Fail Failure | INR (INR (INR b)) => (* Extract the output of odd *) (* Inject into the variant for the output of even *) Return (INR (INL b)) ) | INR (INL _) => (* Output of even *) (* We must ignore this one *) Fail Failure | INR (INR (INL i)) => (* Body of odd *) if i = 0 then Return (INR (INR (INR F))) else (* Call to even *) (case f (INL (i - 1)) of | Fail e => Fail e | Diverge => Diverge | Return r => (* Eliminate the unwanted results *) case r of | INL _ => Fail Failure | INR (INL b) => (* Extract the output of even *) (* Inject into the variant for the output of odd *) Return (INR (INR (INR b))) | INR (INR (INL _)) => Fail Failure | INR (INR (INR _)) => Fail Failure ) | INR (INR (INR _)) => (* Output of odd *) (* We must ignore this one *) Fail Failure ’ Theorem even_odd_body_is_valid_aux: is_valid_fp_body (SUC (SUC n)) even_odd_body Proof prove_body_is_valid_tac (SOME even_odd_body_def) QED Theorem even_odd_body_is_valid: is_valid_fp_body (SUC (SUC 0)) even_odd_body Proof irule even_odd_body_is_valid_aux QED val even_raw_def = Define ‘ even (i : int) = case fix even_odd_body (INL i) of | Fail e => Fail e | Diverge => Diverge | Return r => case r of | INL _ => Fail Failure | INR (INL b) => Return b | INR (INR (INL _)) => Fail Failure | INR (INR (INR _)) => Fail Failure ’ val odd_raw_def = Define ‘ odd (i : int) = case fix even_odd_body (INR (INR (INL i))) of | Fail e => Fail e | Diverge => Diverge | Return r => case r of | INL _ => Fail Failure | INR (INL b) => Fail Failure | INR (INR (INL _)) => Fail Failure | INR (INR (INR b)) => Return b ’ 
(* Recursive equation for [even], derived from the raw fixed-point-based
   definition [even_raw_def].

   Arguments passed to [prove_def_eq_tac] (defined elsewhere):
   - the raw definition of the function whose equation is being proved;
   - the raw definitions of every function in the mutually recursive group;
   - the validity theorem for the group's shared body;
   - the definition of that shared body. *)
Theorem even_def:
  !i. even (i : int) : bool result =
        if i = 0 then Return T else odd (i - 1)
Proof
  prove_def_eq_tac
    even_raw_def
    [even_raw_def, odd_raw_def]
    even_odd_body_is_valid
    even_odd_body_def
QED

(* Recursive equation for [odd], proved the same way from [odd_raw_def]
   with the same group-level facts. *)
Theorem odd_def:
  !i. odd (i : int) : bool result =
        if i = 0 then Return F else even (i - 1)
Proof
  prove_def_eq_tac
    odd_raw_def
    [even_raw_def, odd_raw_def]
    even_odd_body_is_valid
    even_odd_body_def
QED

(* Write out the theory; [export_theory] returns unit. *)
val () = export_theory ()