structure divDefLib :> divDefLib = struct open primitivesBaseTacLib primitivesLib divDefTheory val dbg = ref false fun print_dbg s = if (!dbg) then print s else () val result_ty = “:'a result” val error_ty = “:error” val alpha_ty = “:'a” val num_ty = “:num” val zero_num_tm = “0:num” val suc_tm = “SUC” val return_tm = “Return : 'a -> 'a result” val fail_tm = “Fail : error -> 'a result” val fail_failure_tm = “Fail Failure : 'a result” val diverge_tm = “Diverge : 'a result” (* Switch to use ‘fix_exec’ (leading to executable definitions) and ‘fix’ (non executable) *) val use_fix_exec = ref true val fix_tm = “fix” val fix_exec_tm = “fix_exec” val is_valid_fp_body_tm = “is_valid_fp_body” fun mk_result (ty : hol_type) : hol_type = Type.type_subst [ alpha_ty |-> ty ] result_ty fun dest_result (ty : hol_type) : hol_type = let val {Args=out_ty, Thy=thy, Tyop=tyop} = dest_thy_type ty in if thy = "primitives" andalso tyop = "result" then hd out_ty else failwith "dest_result: not a result" end fun mk_return (x : term) : term = mk_icomb (return_tm, x) fun mk_fail (ty : hol_type) (e : term) : term = mk_comb (inst [ alpha_ty |-> ty ] fail_tm, e) fun mk_fail_failure (ty : hol_type) : term = inst [ alpha_ty |-> ty ] fail_failure_tm fun mk_diverge (ty : hol_type) : term = inst [ alpha_ty |-> ty ] diverge_tm fun mk_suc (n : term) = mk_comb (suc_tm, n) fun enumerate (ls : 'a list) : (int * 'a) list = zip (List.tabulate (List.length ls, fn i => i)) ls (*=============================================================================* * * Generate the (non-recursive) body to give to the fixed-point operator * * ============================================================================*) (* Small helper to generate wrappers of the shape: ‘INL x’, ‘INR (INL x)’, etc. Note that we should have: ‘length before_tys + 1 + length after tys >= 2’ Ex.: ==== The enumeration has type: “: 'a + 'b + 'c + 'd”. We want to generate the variant which injects “x:'c” into this enumeration. We need to split the list of types into: {[ before_tys = [“:'a”, “'b”] tm = “x: 'c” after_tys = [“:'d”] ]} The function will generate: {[ INR (INR (INL x) : 'a + 'b + 'c + 'd ]} (* Debug *) val before_tys = [“:'a”, “:'b”, “:'c”] val tm = “x:'d” val after_tys = [“:'e”, “:'f”] val before_tys = [“:'a”, “:'b”, “:'c”] val tm = “x:'d” val after_tys = [] mk_inl_inr_wrapper before_tys tm after_tys *) fun list_mk_inl_inr (before_tys : hol_type list) (tm : term) (after_tys : hol_type list) : term = let val (before_tys, pat) = if after_tys = [] then let val just_before_ty = List.last before_tys val before_tys = List.take (before_tys, List.length before_tys - 1) val pat = sumSyntax.mk_inr (tm, just_before_ty) in (before_tys, pat) end else (before_tys, sumSyntax.mk_inl (tm, sumSyntax.list_mk_sum after_tys)) val pat = foldr (fn (ty, pat) => sumSyntax.mk_inr (pat, ty)) pat before_tys in pat end (* This function wraps a term into the proper variant of the input/output sum. Ex.: ==== For the input of the first function, we generate: ‘INL x’ For the output of the first function, we generate: ‘INR (INL x)’ For the input of the 2nd function, we generate: ‘INR (INR (INL x))’ etc. If ‘is_input’ is true: we are wrapping an input. Otherwise we are wrapping an output. (* Debug *) val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”), (“:'e”, “:'f”)] val j = 1 val tm = “x:'c” val tm = “y:'d” val is_input = true *) fun inject_in_param_sum (tys : (hol_type * hol_type) list) (j : int) (is_input : bool) (tm : term) : term = let fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) val before_tys = flatten (List.take (tys, j)) val (input_ty, output_ty) = List.nth (tys, j) val after_tys = flatten (List.drop (tys, j + 1)) val (before_tys, after_tys) = if is_input then (before_tys, output_ty :: after_tys) else (before_tys @ [input_ty], after_tys) in list_mk_inl_inr before_tys tm after_tys end (* Remark: the order of the branches when creating matches is important. For instance, in the case of ‘result’ it must be: ‘Return’, ‘Fail’, ‘Diverge’. For the purpose of stability and maintainability, we introduce this small helper which reorders the cases in a pattern before actually creating the case expression. *) fun unordered_mk_case (scrut: term, pats: (term * term) list) : term = let (* Retrieve the constructors *) val cl = TypeBase.constructors_of (type_of scrut) (* Retrieve the names of the constructors *) val names = map (fst o dest_const) cl (* Use those to reorder the patterns *) fun is_pat (name : string) (pat, _) = let val app = (fst o strip_comb) pat val app_name = (fst o dest_const) app in app_name = name end val pats = map (fn name => valOf (List.find (is_pat name) pats)) names in (* Create the case *) TypeBase.mk_case (scrut, pats) end (* Wrap a term of type “:'a result” into a ‘case of’ which matches over the result. Ex.: ==== {[ f x ~~> case f x of | Fail e => Fail e | Diverge => Diverge | Return y => ... (* The branch content is generated by the continuation *) ]} ‘gen_ret_branch’ is a *continuation* which generates the content of the ‘Return’ branch (i.e., the content of the ‘...’ in the example above). It receives as input the value contained by the ‘Return’ (i.e., the variable ‘y’ in the example above). Remark.: the type of the term generated by ‘gen_ret_branch’ must have the type ‘result’, but it can change the content of the result (i.e., if ‘scrut’ has type ‘:'a result’, we can change the type of the wrapped expression to ‘:'b result’). (* Debug *) val scrut = “x: int result” fun gen_ret_branch x = mk_return x val scrut = “x: int result” fun gen_ret_branch _ = “Return T” mk_result_case scrut gen_ret_branch *) fun mk_result_case (scrut : term) (gen_ret_branch : term -> term) : term = let val scrut_ty = dest_result (type_of scrut) (* Return branch *) val ret_var = genvar scrut_ty val ret_pat = mk_return ret_var val ret_br = gen_ret_branch ret_var val ret_ty = dest_result (type_of ret_br) (* Failure branch *) val fail_var = genvar error_ty val fail_pat = mk_fail scrut_ty fail_var val fail_br = mk_fail ret_ty fail_var (* Diverge branch *) val div_pat = mk_diverge scrut_ty val div_br = mk_diverge ret_ty in unordered_mk_case (scrut, [(ret_pat, ret_br), (fail_pat, fail_br), (div_pat, div_br)]) end (* Generate a ‘case ... of’ over a sum type. Ex.: ==== If the scrutinee is: “x : 'a + 'b + 'c” (i.e., the tys list is: [“:'a”, “:b”, “:c”]), we generate: {[ case x of | INL y0 => ... (* Branch of index 0 *) | INR (INL y1) => ... (* Branch of index 1 *) | INR (INR (INL y2)) => ... (* Branch of index 2 *) | INR (INR (INR y3)) => ... (* Branch of index 3 *) ]} The content of the branches is generated by the ‘gen_branch’ continuation, which receives as input the index of the branch as well as the variable introduced by the pattern (in the example above: ‘y0’ for the branch 0, ‘y1’ for the branch 1, etc.) (* Debug *) val tys = [“:'a”, “:'b”] val scrut = mk_var ("x", sumSyntax.list_mk_sum tys) fun gen_branch i (x : term) = “F” val tys = [“:'a”, “:'b”, “:'c”, “:'d”] val scrut = mk_var ("x", sumSyntax.list_mk_sum tys) fun gen_branch i (x : term) = if type_of x = “:'c” then mk_return x else mk_fail_failure “:'c” list_mk_sum_case scrut tys gen_branch *) (* For debugging *) val list_mk_sum_case_case = ref (“T”, [] : (term * term) list) (* val (scrut, [(pat1, br1), (pat2, br2)]) = !list_mk_sum_case_case *) fun list_mk_sum_case (scrut : term) (tys : hol_type list) (gen_branch : int -> term -> term) : term = let (* Create the cases. Note that without sugar, the match actually looks like this: {[ case x of | INL y0 => ... (* Branch of index 0 *) | INR x1 case x1 of | INL y1 => ... (* Branch of index 1 *) | INR x2 => case x2 of | INL y2 => ... (* Branch of index 2 *) | INR y3 => ... (* Branch of index 3 *) ]} *) fun create_case (j : int) (scrut : term) (tys : hol_type list) : term = let val _ = print_dbg ("list_mk_sum_case: " ^ String.concatWith ", " (map type_to_string tys) ^ "\n") in case tys of [] => failwith "tys is too short" | [ ty ] => (* Last element: no match to perform *) gen_branch j scrut | ty1 :: tys => (* Not last: we create a pattern: {[ case scrut of | INL pat_var1 => ... (* Branch of index i *) | INR pat_var2 => ... (* Generate this term recursively *) ]} *) let (* INL branch *) val after_ty = sumSyntax.list_mk_sum tys val pat_var1 = genvar ty1 val pat1 = sumSyntax.mk_inl (pat_var1, after_ty) val br1 = gen_branch j pat_var1 (* INR branch *) val pat_var2 = genvar after_ty val pat2 = sumSyntax.mk_inr (pat_var2, ty1) val br2 = create_case (j+1) pat_var2 tys val _ = print_dbg ("list_mk_sum_case: assembling:\n" ^ term_to_string scrut ^ ",\n" ^ "[(" ^ term_to_string pat1 ^ ",\n " ^ term_to_string br1 ^ "),\n\n" ^ " (" ^ term_to_string pat2 ^ ",\n " ^ term_to_string br2 ^ ")]\n\n") val case_elems = (scrut, [(pat1, br1), (pat2, br2)]) val _ = list_mk_sum_case_case := case_elems in (* Put everything together *) TypeBase.mk_case case_elems end end in create_case 0 scrut tys end (* Generate a ‘case ... of’ to select the input/output of the ith variant of the param enumeration. Ex.: ==== There are two functions in the group, and we select the input of the function of index 1: {[ case x of | INL _ => Fail Failure (* Input of function of index 0 *) | INR (INL _) => Fail Failure (* Output of function of index 0 *) | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *) | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *) ]} (* Debug *) val tys = [(“:'a”, “:'b”)] val scrut = “x : 'a + 'b” val fi = 0 val is_input = true val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”)] val scrut = “x : 'a + 'b + 'c + 'd” val fi = 1 val is_input = false val scrut = mk_var ("x", sumSyntax.list_mk_sum (flatten tys)) list_mk_case_select scrut tys fi is_input *) fun list_mk_case_sum_select (scrut : term) (tys : (hol_type * hol_type) list) (fi : int) (is_input : bool) : term = let (* The index of the element in the enumeration that we will select *) val i = 2 * fi + (if is_input then 0 else 1) (* Flatten the types and numerotate them *) fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) val tys = flatten tys (* Get the return type *) val ret_ty = List.nth (tys, i) (* The continuation which will generate the content of the branches *) fun gen_branch j var = if j = i then mk_return var else mk_fail_failure ret_ty in (* Generate the ‘case ... of’ *) list_mk_sum_case scrut tys gen_branch end (* Generate a ‘case ... of’ to select the input/output of the ith variant of the param enumeration. Ex.: ==== There are two functions in the group, and we select the input of the function of index 1: {[ case x of | Fail e => Fail e | Diverge => Diverge | Return r => case r of | INL _ => Fail Failure (* Input of function of index 0 *) | INR (INL _) => Fail Failure (* Output of function of index 0 *) | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *) | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *) ]} *) fun mk_case_select_result_sum (scrut : term) (tys : (hol_type * hol_type) list) (fi : int) (is_input : bool) : term = (* We match over the result, then over the enumeration *) mk_result_case scrut (fn x => list_mk_case_sum_select x tys fi is_input) (* Generate a body for the fixed-point operator from a quoted group of mutually recursive definitions. See TODO for detailed explanations: from the quoted equations for ‘nth’ (or for [‘even’, ‘odd’]) we generate the body ‘nth_body’ (or ‘even_odd_body’, respectively). *) fun mk_body (fnames : string list) (in_out_tys : (hol_type * hol_type) list) (def_tms : term list) : term = let val fnames_set = Redblackset.fromList String.compare fnames (* Compute a map from function name to function index *) val fnames_map = Redblackmap.fromList String.compare (map (fn (x, y) => (y, x)) (enumerate fnames)) (* Compute the input/output type, that we dub the "parameter type" *) fun flatten ls = List.concat (map (fn (x, y) => [x, y]) ls) val param_type = sumSyntax.list_mk_sum (flatten in_out_tys) (* Introduce a variable for the confinuation *) val fcont = genvar (param_type --> mk_result param_type) (* In the function equations, replace all the recursive calls with calls to the continuation. When replacing a recursive call, we have to do two things: - we need to inject the input parameters into the parameter type Ex.: - ‘nth tl i’ becomes ‘f (INL (tl, i))’ where ‘f’ is the continuation - ‘even i’ becomes ‘f (INL i)’ where ‘f’ is the continuation - we need to wrap the the call to the continuation into a ‘case ... of’ to extract its output (we need to make sure that the transformation preserves the type of the expression!) Ex.: ‘nth tl i’ becomes: {[ case f (INL (tl, i)) of | Fail e => Fail e | Diverge => Diverge | Return r => case r of | INL _ => Fail Failure | INR x => Return (INR x) ]} *) (* For debugging *) val replace_rec_calls_rec_call_tm = ref “T” fun replace_rec_calls (fnames_set : string Redblackset.set) (tm : term) : term = let val _ = print_dbg ("replace_rec_calls: original expression:\n" ^ term_to_string tm ^ "\n\n") val ntm = case dest_term tm of VAR (name, ty) => (* Check that this is not one of the functions in the group - remark: we could handle that by introducing lambdas. *) if Redblackset.member (fnames_set, name) then failwith ("mk_body: not well-formed definition: found " ^ name ^ " in an improper position") else tm | CONST _ => tm | LAMB (x, tm) => let (* The variable might shadow one of the functions: remove it from the set of function names - remark: Redblackset.delete raises [NotFound] if the value is not present in the set *) val varname = (fst o dest_var) x val fnames_set = if Redblackset.member (fnames_set, varname) then Redblackset.delete (fnames_set, varname) else fnames_set (* Update the term in the lambda *) val tm = replace_rec_calls fnames_set tm in (* Reconstruct *) mk_abs (x, tm) end | COMB (_, _) => let (* Completely destruct the application, check if this is a recursive call *) val (app, args) = strip_comb tm val is_rec_call = Redblackset.member (fnames_set, (fst o dest_var) app) handle HOL_ERR _ => false (* Whatever the case, apply the transformation to all the inputs *) val args = map (replace_rec_calls fnames_set) args in (* If this is not a recursive call: apply the transformation to all the terms. Otherwise, replace. *) if not is_rec_call then list_mk_comb (replace_rec_calls fnames_set app, args) else (* Rec call: replace *) let val _ = print_dbg ("replace_rec_calls: rec call\n\n") val _ = replace_rec_calls_rec_call_tm := tm (* First, find the index of the function *) val fname = (fst o dest_var) app val fi = Redblackmap.find (fnames_map, fname) (* Inject the input values into the param type *) val input = pairSyntax.list_mk_pair args val input = inject_in_param_sum in_out_tys fi true input (* Create the recursive call *) val call = mk_comb (fcont, input) (* Wrap the call into a ‘case ... of’ to extract the output *) val call = mk_case_select_result_sum call in_out_tys fi false in (* Return *) call end end val _ = print_dbg ("replace_rec_calls: new expression:\n" ^ term_to_string ntm ^ "\n\n") in ntm end handle HOL_ERR e => let val _ = print_dbg ("replace_rec_calls: failed on:\n" ^ term_to_string tm ^ "\n\n") in raise (HOL_ERR e) end fun replace_rec_calls_in_eq (eq : term) : term = let val (l, r) = dest_eq eq in mk_eq (l, replace_rec_calls fnames_set r) end val def_tms_with_fcont = map replace_rec_calls_in_eq def_tms (* Wrap all the function bodies to inject their result into the param type. We collect the function inputs at the same time, because they will be grouped into a tuple that we will have to deconstruct. *) fun inject_body_to_enums (i : int, def_eq : term) : term list * term = let val (l, body) = dest_eq def_eq val (_, args) = strip_comb l (* We have the deconstruct the result, then, in the ‘Return’ branch, properly wrap the returned value *) val body = mk_result_case body (fn x => mk_return (inject_in_param_sum in_out_tys i false x)) in (args, body) end val def_tms_inject = map inject_body_to_enums (enumerate def_tms_with_fcont) (* Currify the body inputs. For instance, if the body has inputs: ‘x’, ‘y’; we return the following: {[ (‘z’, ‘case z of (x, y) => ... (* body *) ’) ]} where ‘z’ is fresh. We return: (curried input, body). (* Debug *) val body = “(x:'a, y:'b, z:'c)” val args = [“x:'a”, “y:'b”, “z:'c”] currify_body_inputs (args, body) *) fun currify_body_inputs (args : term list, body : term) : term * term = let fun mk_curry (args : term list) (body : term) : term * term = case args of [] => failwith "no inputs" | [x] => (x, body) | x1 :: args => let val (x2, body) = mk_curry args body val scrut = genvar (pairSyntax.list_mk_prod (map type_of (x1 :: args))) val pat = pairSyntax.mk_pair (x1, x2) val br = body in (scrut, TypeBase.mk_case (scrut, [(pat, br)])) end in mk_curry args body end val def_tms_currified = map currify_body_inputs def_tms_inject (* Group all the functions into a single body, with an outer ‘case .. of’ which selects the appropriate body depending on the input *) val param_ty = sumSyntax.list_mk_sum (flatten in_out_tys) val input = genvar param_ty fun mk_mut_rec_body_branch (i : int) (patvar : term) : term = (* Case disjunction on whether the branch is for an input value (in which case we call the proper body) or an output value (in which case we return ‘Fail ...’ *) if i mod 2 = 0 then let val fi = i div 2 val (x, def_tm) = List.nth (def_tms_currified, fi) (* The variable in the pattern and the variable expected by the body may not be the same: we introduce a let binding *) val def_tm = mk_let (mk_abs (x, def_tm), patvar) in def_tm end else (* Output value: fail *) mk_fail_failure param_ty val mut_rec_body = list_mk_sum_case input (flatten in_out_tys) mk_mut_rec_body_branch (* Abstract away the parameters to produce the final body of the fixed point *) val mut_rec_body = list_mk_abs ([fcont, input], mut_rec_body) in mut_rec_body end (*=============================================================================* * * Prove that the body satisfies the validity condition * * ============================================================================*) (* Tactic to prove that a body is valid: perform one step. *) fun prove_body_is_valid_tac_step (asms, g) = let (* The goal has the shape: {[ (∀g h. ... g x = ... h x) ∨ ∃h y. is_valid_fp_body n h ∧ ∀g. ... g x = ... od ]} *) (* Retrieve the scrutinee in the goal (‘x’) *) val body = (lhs o snd o strip_forall o fst o dest_disj) g val scrut = strip_all_cases_get_scrutinee_or_curried body (* Retrieve the first quantified continuations from the goal (‘g’) *) val qc = (hd o fst o strip_forall o fst o dest_disj) g (* Check if the scrutinee is a recursive call *) val (scrut_app, _) = strip_comb scrut val _ = print_dbg ("prove_body_is_valid_step: Scrutinee: " ^ term_to_string scrut ^ "\n") (* For the recursive calls: *) fun step_rec () = let val _ = print_dbg ("prove_body_is_valid_step: rec call\n") (* We need to instantiate the ‘h’ existantially quantified function *) (* First, retrieve the body of the function: it is given by the ‘Return’ branch *) val (_, _, branches) = TypeBase.dest_case body (* Find the branch corresponding to the return *) val ret_branch = List.find (fn (pat, _) => let val {Name=name, Thy=thy, Ty = _ } = (dest_thy_const o fst o strip_comb) pat in thy = "primitives" andalso name = "Return" end) branches val var = (hd o snd o strip_comb o fst o valOf) ret_branch val br = (snd o valOf) ret_branch (* Abstract away the input variable introduced by the pattern and the continuation ‘g’ *) val h = list_mk_abs ([qc, var], br) val _ = print_dbg ("prove_body_is_valid_step: h: " ^ term_to_string h ^ "\n") (* Retrieve the input parameter ‘x’ *) val input = (snd o dest_comb) scrut val _ = print_dbg ("prove_body_is_valid_step: y: " ^ term_to_string input ^ "\n") in ((* Choose the right possibility (this is a recursive call) *) disj2_tac >> (* Instantiate the quantifiers *) qexists ‘^h’ >> qexists ‘^input’ >> (* Unfold the predicate once *) pure_once_rewrite_tac [is_valid_fp_body_def] >> (* We have two subgoals: - we have to prove that ‘h’ is valid - we have to finish the proof of validity for the current body *) conj_tac >> fs [case_result_switch_eq, bind_def] >> (* The first subgoal should have been eliminated *) gen_tac) end in (* If recursive call: special treatment. Otherwise, we do a simple disjunction *) (if term_eq scrut_app qc then step_rec () else (Cases_on ‘^scrut’ >> fs [case_result_switch_eq])) (asms, g) end (* Tactic to prove that a body is valid *) fun prove_body_is_valid_tac (body_def : thm option) : tactic = let val body_def_thm = case body_def of SOME th => [th] | NONE => [] in pure_once_rewrite_tac [is_valid_fp_body_def] >> gen_tac >> (* Expand *) fs body_def_thm >> fs [bind_def, case_result_switch_eq] >> (* Explore the body *) rpt prove_body_is_valid_tac_step end (* Prove that a body satisfies the validity condition of the fixed point *) fun prove_body_is_valid (body : term) : thm = let (* Explore the body and count the number of occurrences of recursive calls so that we can properly instantiate the ‘N’ argument of ‘is_valid_fp_body’ (note: we compute an overapproximation). We first retrieve the name of the continuation parameter. Rem.: we generated fresh names so that, for instance, the continuation name doesn't collide with other names. Because of this, we don't need to look for collisions when exploring the body (and in the worst case, we would cound an overapproximation of the number of recursive calls). *) val fcont = (hd o fst o strip_abs) body val fcont_name = (fst o dest_var) fcont fun count_body_rec_calls (body : term) : int = case dest_term body of VAR (name, _) => if name = fcont_name then 1 else 0 | CONST _ => 0 | COMB (x, y) => count_body_rec_calls x + count_body_rec_calls y | LAMB (_, x) => count_body_rec_calls x val num_rec_calls = count_body_rec_calls body (* Generate the term ‘SUC (SUC ... (SUC n))’ where ‘n’ is a fresh variable. Remark: we first prove ‘is_valid_fp_body (SUC ... n) body’ then substitue ‘n’ with ‘0’ to prevent the quantity from being rewritten to a bit representation, which would prevent unfolding of the ‘is_valid_fp_body’. *) val nvar = genvar num_ty (* Rem.: we stack num_rec_calls + 1 occurrences of ‘SUC’ (and the + 1 is important) *) fun mk_n i = if i = 0 then mk_suc nvar else mk_suc (mk_n (i-1)) val n_tm = mk_n num_rec_calls (* Generate the lemma statement *) val is_valid_tm = list_mk_icomb (is_valid_fp_body_tm, [n_tm, body]) val is_valid_thm = prove (is_valid_tm, prove_body_is_valid_tac NONE) (* Replace ‘nvar’ with ‘0’ *) val is_valid_thm = INST [nvar |-> zero_num_tm] is_valid_thm in is_valid_thm end (*=============================================================================* * * Generate the definitions with the fixed-point operator * * ============================================================================*) (* Generate the raw definitions by using the grouped definition body and the fixed point operator *) fun mk_raw_defs (in_out_tys : (hol_type * hol_type) list) (def_tms : term list) (body_is_valid : thm) : thm list = let (* Retrieve the body *) val body = (List.last o snd o strip_comb o concl) body_is_valid (* Create the term ‘fix_exec body’ *) val fixed_body = mk_icomb (if !use_fix_exec then fix_exec_tm else fix_tm, body) (* For every function in the group, generate the equation that we will use as definition. In particular: - add the properly injected input ‘x’ to ‘fix body’ (ex.: for ‘nth ls i’ we add the input ‘INL (ls, i)’) - wrap ‘fix body x’ into a case disjunction to extract the relevant output For instance, in the case of ‘nth ls i’: {[ nth (ls : 't list_t) (i : u32) = case fix nth_body (INL (ls, i)) of | Fail e => Fail e | Diverge => Diverge | Return r => case r of | INL _ => Fail Failure | INR x => Return x ]} *) fun mk_def_eq (i : int, def_tm : term) : term = let (* Retrieve the lhs of the original definition equation, and in particular the inputs *) val def_lhs = lhs def_tm val args = (snd o strip_comb) def_lhs (* Inject the inputs into the param type *) val input = pairSyntax.list_mk_pair args val input = inject_in_param_sum in_out_tys i true input (* Compose*) val def_rhs = mk_comb (fixed_body, input) (* Wrap in the case disjunction *) val def_rhs = mk_case_select_result_sum def_rhs in_out_tys i false (* Create the equation *) val def_eq_tm = mk_eq (def_lhs, def_rhs) in def_eq_tm end val raw_def_tms = map mk_def_eq (enumerate def_tms) (* Generate the definitions *) val raw_defs = map (fn tm => Define ‘^tm’) raw_def_tms in raw_defs end (*=============================================================================* * * Prove that the definitions satisfy the target equations * * ============================================================================*) (* Tactic which makes progress in a proof by making a case disjunction (we use this to explore all the paths in a function body). *) fun case_progress (asms, g) = let val scrut = (strip_all_cases_get_scrutinee o lhs) g in Cases_on ‘^scrut’ (asms, g) end (* Prove the final equation, that we will use as definition. *) fun prove_def_eq_tac (current_raw_def : thm) (all_raw_defs : thm list) (is_valid : thm) (body_def : thm option) : tactic = let val body_def_thm = case body_def of SOME th => [th] | NONE => [] val fix_eq = if !use_fix_exec then fix_exec_fixed_eq else fix_fixed_eq in rpt gen_tac >> (* Expand the definition *) pure_once_rewrite_tac [current_raw_def] >> (* Use the fixed-point equality *) pure_once_rewrite_left_tac [HO_MATCH_MP fix_eq is_valid] >> (* Expand the body definition *) pure_rewrite_tac body_def_thm >> (* Expand all the definitions from the group *) pure_rewrite_tac all_raw_defs >> (* Explore all the paths - maybe we can be smarter, but this is fast and really easy *) fs [bind_def] >> rpt (case_progress >> fs []) end (* Prove the final equations that we will give to the user as definitions *) fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm list) : thm list= let val defs_tgt_raw = zip def_tms raw_defs (* Substitute the function variables with the constants introduced in the raw definitions *) fun compute_fsubst (def_tm, raw_def) : {redex: term, residue: term} = let val (fvar, _) = (strip_comb o lhs) def_tm val fconst = (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def in (fvar |-> fconst) end val fsubst = map compute_fsubst defs_tgt_raw val defs_tgt_raw = map (fn (x, y) => (subst fsubst x, y)) defs_tgt_raw fun prove_def_eq (def_tm, raw_def) : thm = let (* Quantify the parameters *) val (_, params) = (strip_comb o lhs) def_tm val def_eq_tm = list_mk_forall (params, def_tm) (* Prove *) val def_eq = prove (def_eq_tm, prove_def_eq_tac raw_def raw_defs body_is_valid NONE) in def_eq end val def_eqs = map prove_def_eq defs_tgt_raw in def_eqs end (*=============================================================================* * * The final DefineDiv function * * ============================================================================*) type absyn = Absyn.absyn (* Helper: convert an absyn to a vstruct (i.e., turn a "standard" term into a quantified term; we use it to transform function arguments into abstracted terms (in a lambda) *) fun absyn_to_vstruct (x : absyn) : Absyn.vstruct = case x of Absyn.AQ (l, t) => Absyn.VAQ (l, t) | Absyn.IDENT (l, s) => Absyn.VIDENT (l, s) | Absyn.QIDENT _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: QIDENT") | Absyn.APP _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: APP") | Absyn.LAM _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: LAM") | Absyn.TYPED (l, y, ty) => Absyn.VTYPED (l, absyn_to_vstruct y, ty) (* We need to parse the quotation in a specific manner. The issue is that, with mutually recursive functions, the parser sometimes gets confused if some funtions have parameters with the same name but with different types. For instance: {[ f (x : int) = ... /\ g (x : bool) = ... ]} The solution is to rewrite the equations to make lambdas appear explicitely, like so: {[ f = λ(x : int) = ... /\ g = λ(x : bool) = ... ]} We do the following: - we convert the quotation to an abstract syntax tree - transform this tree into a shape where function bodies are abstractions - parse this to a term - change the shape of the term back to the original shape (with arguments on the left of the “=”) *) fun parse_quote (defs_qt : term quotation) : term = let val def_abs = Parse.Absyn defs_qt val absl = Absyn.strip_conj def_abs (* Turn an equation of the shape “f x = ...” into “f = \x. ...” *) fun make_lambda_def (def_abs : absyn) : absyn = let (* Retrieve the body *) val (app, body) = Absyn.dest_eq def_abs (* Remove the typing annotation from around the lhs, if there is, and put it around the rhs *) val (app, body) = if Absyn.is_typed app then let val (app, ty) = Absyn.dest_typed app in (app, Absyn.mk_typed (body, ty)) end else (app, body) (* Strip the arguments *) val (f, args) = Absyn.strip_app app (* Make a lambda abstraction *) val args = map absyn_to_vstruct args val body = Absyn.list_mk_lam (args, body) in Absyn.mk_eq (f, body) end val absl = map make_lambda_def absl val def_abs = Absyn.list_mk_conj absl (* Parse the quote now that it is in the proper shape *) val def_tm = (* This is taken from Defn.sml: we removed the [sort_eqns] because it is not useful in our case (its untangle the dependencies so that functions are defined before use). *) fst (Defn.parse_absyn def_abs) handle e => raise wrap_exn "divDefLib" "parse_quote" e (* Put the definition back into the original shape *) fun make_args_def (tm : term) : term = let val (f, body) = dest_eq tm val (args, body) = strip_abs body in mk_eq (list_mk_comb (f, args), body) end val def_tms = strip_conj def_tm val def_tms = map make_args_def def_tms val def_tm = list_mk_conj def_tms in def_tm end fun DefineDiv (def_qt : term quotation) = let (* Parse the definitions *) val def_tms = strip_conj (parse_quote def_qt) (* Compute the names and the input/output types of the functions *) fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) = let val app = lhs tm val name = (fst o dest_var o fst o strip_comb) app val out_ty = dest_result (type_of app) val input_tys = pairSyntax.list_mk_prod (map type_of ((snd o strip_comb) app)) in (name, (input_tys, out_ty)) end val (fnames, in_out_tys) = unzip (map compute_names_in_out_tys def_tms) (* Generate the body to give to the fixed-point operator *) val body = mk_body fnames in_out_tys def_tms (* Prove that the body satisfies the validity property required by the fixed point *) val body_is_valid = prove_body_is_valid body (* Generate the definitions for the various functions by using the fixed point and the body *) val raw_defs = mk_raw_defs in_out_tys def_tms body_is_valid (* Prove the final equations *) val def_eqs = prove_def_eqs body_is_valid def_tms raw_defs (* Save the final equations as definitions. *) val thm_names = map (fn x => x ^ "_def") fnames (* Because [store_definition] overrides existing names, it seems that in practice we don't really need to delete the previous definitions (we still do it: it doesn't cost much). *) val _ = List.app delete_binding thm_names val _ = map store_definition (zip thm_names def_eqs) (* Also save the custom unfoldings, for evaluation (unit tests) *) val _ = evalLib.add_unfold_thms thm_names in def_eqs end end