diff options
Diffstat (limited to '')
-rw-r--r-- | compiler/Extract.ml | 151 |
1 files changed, 136 insertions, 15 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 7670c753..f45b9b58 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -489,6 +489,16 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ lp_suffix ^ suffix in + let terminates_clause_name (_fid : A.FunDeclId.id) (fname : fun_name) + (num_loops : int) (loop_id : LoopId.id option) : string = + let fname = fun_name_to_snake_case fname in + let lp_suffix = default_fun_loop_suffix num_loops loop_id in + (* Compute the suffix *) + let suffix = "_terminates" in + (* Concatenate *) + fname ^ lp_suffix ^ suffix + in + let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) : string = (* If there is a basename, we use it *) @@ -619,6 +629,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) global_name; fun_name; decreases_clause_name; + terminates_clause_name; var_basename; type_var_basename; append_index; @@ -1348,7 +1359,11 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) let (fwd, loop_fwds), back_ls = def in (* Register the decrease clauses, if necessary *) let register_decreases ctx def = - if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx + if has_decreases_clause def then + let ctx = ctx_add_decreases_clause def ctx in + ctx_add_terminates_clause def ctx + else + ctx in let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in (* Register the function names *) @@ -1626,31 +1641,40 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) in let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in if is_lean_struct then - (* TODO: enclosing curly brace is indented too far to the right *) (* TODO: when only one or two fields differ, considering using the with syntax (peephole optimization) *) let decl_id = match adt_cons.adt_id with | AdtId id -> id | _ -> assert false in let def_kind = (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls).kind in let fields = match def_kind with | T.Struct fields -> fields | _ -> assert false in let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - F.pp_open_vbox fmt ctx.indent_incr; + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr; F.pp_print_string fmt "{"; F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + F.pp_open_hvbox fmt 0; Collections.List.iter_link (fun () -> F.pp_print_string fmt ","; + F.pp_close_box fmt (); F.pp_print_space fmt () ) (fun ((fid, _), e) -> + F.pp_open_hovbox fmt 0; let f = ctx_get_field adt_cons.adt_id fid ctx in F.pp_print_string fmt f; F.pp_print_string fmt " := "; - extract_texpression ctx fmt true e + F.pp_open_hvbox fmt ctx.indent_incr; + extract_texpression ctx fmt true e; + F.pp_close_box fmt () ) (List.combine fields args); - F.pp_print_space fmt (); F.pp_close_box fmt (); - F.pp_print_string fmt "}"; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_print_space fmt (); + F.pp_print_string fmt "}" else let use_parentheses = inside && args <> [] in if use_parentheses then F.pp_print_string fmt "("; @@ -2015,7 +2039,7 @@ let assert_backend_supports_decreases_clauses () = *) let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = - assert_backend_supports_decreases_clauses (); + assert (!backend = FStar); (* Retrieve the function name *) let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in @@ -2028,16 +2052,14 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) * one line *) F.pp_open_hvbox fmt 0; (* Add the [unfold] keyword *) - if !backend = FStar then begin - F.pp_print_string fmt "unfold"; - F.pp_print_space fmt (); - end; + F.pp_print_string fmt "unfold"; + F.pp_print_space fmt (); (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) F.pp_open_hvbox fmt ctx.indent_incr; (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "let FUN_NAME" *) - F.pp_print_string fmt ((if !backend = FStar then "let " else "def ") ^ def_name); + F.pp_print_string fmt ("let " ^ def_name); F.pp_print_space fmt (); (* Extract the parameters *) let space = ref true in @@ -2046,20 +2068,119 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ":"; (* Print the signature *) F.pp_print_space fmt (); - F.pp_print_string fmt (if !backend = FStar then "nat" else "Nat"); + F.pp_print_string fmt "nat"; (* Print the "=" *) F.pp_print_space fmt (); - F.pp_print_string fmt (if !backend = FStar then "=" else ":="); + F.pp_print_string fmt "="; (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) F.pp_close_box fmt (); F.pp_print_space fmt (); (* Print the "admit ()" *) - F.pp_print_string fmt (if !backend = FStar then "admit ()" else "sorry"); + F.pp_print_string fmt "admit ()"; + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) + F.pp_close_box fmt (); + (* Close the box for the whole definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Extract templates for the termination_by and decreases_by clauses of a + recursive function definition. + + For Lean only. + + We extract two commands. The first one is a regular definition for the + termination measure (the value derived from the function arguments that + decreases over function calls). The second one is a macro definition that + defines a proof script (allowed to refer to function arguments) that proves + termination. +*) +let extract_termination_and_decreasing (ctx: extraction_ctx) (fmt: F.formatter) (def: fun_decl): unit = + assert (!backend = Lean); + + (* Retrieve the function name *) + let def_name = ctx_get_terminates_clause def.def_id def.loop_id ctx in + let def_body = Option.get def.body in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]: termination measure"); + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt 0; + (* Add the [unfold] keyword *) + F.pp_print_string fmt "@[simp]"; + F.pp_print_space fmt (); + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) + F.pp_open_hvbox fmt ctx.indent_incr; + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "let FUN_NAME" *) + F.pp_print_string fmt ("def " ^ def_name); + F.pp_print_space fmt (); + (* Extract the parameters *) + let space = ref true in + let _, ctx_body = extract_fun_parameters space ctx fmt def in + (* Print the ":=" *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":="; + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + (* Tuple of the arguments *) + let vars = List.map (fun (v: var) -> v.id) def_body.inputs in + if List.length vars = 1 then + F.pp_print_string fmt (ctx_get_var (List.hd vars) ctx_body) + else begin + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "("; + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun v -> F.pp_print_string fmt (ctx_get_var v ctx_body)) + vars; + F.pp_print_string fmt ")"; + F.pp_close_box fmt (); + end; (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) F.pp_close_box fmt (); (* Close the box for the whole definition *) F.pp_close_box fmt (); (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0; + + (* Now extract a template for the termination proof *) + let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in + (* syntax <def_name> term ... term : tactic *) + F.pp_print_break fmt 0 0; + F.pp_open_hvbox fmt 0; + F.pp_print_string fmt "syntax \""; + F.pp_print_string fmt def_name; + F.pp_print_string fmt "\" term+ : tactic"; + F.pp_print_break fmt 0 0; + F.pp_print_break fmt 0 0; + (* macro_rules | `(tactic| fact_termination_proof $x) => `(tactic| ( *) + F.pp_print_string fmt "macro_rules"; + F.pp_print_space fmt (); + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "| `(tactic| "; + F.pp_print_string fmt def_name; + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (fun v -> + F.pp_print_string fmt "$"; + F.pp_print_string fmt (ctx_get_var v ctx_body)) + vars; + F.pp_print_string fmt ") =>"; + F.pp_close_box fmt (); + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "`(tactic| sorry)"; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_close_box fmt (); F.pp_print_break fmt 0 0 (** Extract a function declaration. |