diff options
author | Jonathan Protzenko | 2023-01-30 18:05:21 -0800 |
---|---|---|
committer | Son HO | 2023-06-04 21:44:33 +0200 |
commit | 9804a5f28cedc79ac89d3b97ec6addb42752df3d (patch) | |
tree | 3549c94a08498578f3cfd145475891f45d4ba422 /compiler | |
parent | 1d6742c059cf53e73c9bc66cec7ac1f857830e78 (diff) |
Fix some printing bits, proper syntax for terminates and decreases clauses
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Extract.ml | 151 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 46 | ||||
-rw-r--r-- | compiler/Translate.ml | 5 |
3 files changed, 186 insertions, 16 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 7670c753..f45b9b58 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -489,6 +489,16 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ lp_suffix ^ suffix in + let terminates_clause_name (_fid : A.FunDeclId.id) (fname : fun_name) + (num_loops : int) (loop_id : LoopId.id option) : string = + let fname = fun_name_to_snake_case fname in + let lp_suffix = default_fun_loop_suffix num_loops loop_id in + (* Compute the suffix *) + let suffix = "_terminates" in + (* Concatenate *) + fname ^ lp_suffix ^ suffix + in + let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) : string = (* If there is a basename, we use it *) @@ -619,6 +629,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) global_name; fun_name; decreases_clause_name; + terminates_clause_name; var_basename; type_var_basename; append_index; @@ -1348,7 +1359,11 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) let (fwd, loop_fwds), back_ls = def in (* Register the decrease clauses, if necessary *) let register_decreases ctx def = - if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx + if has_decreases_clause def then + let ctx = ctx_add_decreases_clause def ctx in + ctx_add_terminates_clause def ctx + else + ctx in let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in (* Register the function names *) @@ -1626,31 +1641,40 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) in let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in if is_lean_struct then - (* TODO: enclosing curly brace is indented too far to the right *) (* TODO: when only one or two fields differ, considering using the with syntax (peephole optimization) *) let decl_id = match adt_cons.adt_id with | AdtId id -> id | _ -> assert false in let def_kind = (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls).kind in let fields = match def_kind with | T.Struct fields -> fields | _ -> assert false in let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - F.pp_open_vbox fmt ctx.indent_incr; + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr; F.pp_print_string fmt "{"; F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + F.pp_open_hvbox fmt 0; Collections.List.iter_link (fun () -> F.pp_print_string fmt ","; + F.pp_close_box fmt (); F.pp_print_space fmt () ) (fun ((fid, _), e) -> + F.pp_open_hovbox fmt 0; let f = ctx_get_field adt_cons.adt_id fid ctx in F.pp_print_string fmt f; F.pp_print_string fmt " := "; - extract_texpression ctx fmt true e + F.pp_open_hvbox fmt ctx.indent_incr; + extract_texpression ctx fmt true e; + F.pp_close_box fmt () ) (List.combine fields args); - F.pp_print_space fmt (); F.pp_close_box fmt (); - F.pp_print_string fmt "}"; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_print_space fmt (); + F.pp_print_string fmt "}" else let use_parentheses = inside && args <> [] in if use_parentheses then F.pp_print_string fmt "("; @@ -2015,7 +2039,7 @@ let assert_backend_supports_decreases_clauses () = *) let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = - assert_backend_supports_decreases_clauses (); + assert (!backend = FStar); (* Retrieve the function name *) let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in @@ -2028,16 +2052,14 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) * one line *) F.pp_open_hvbox fmt 0; (* Add the [unfold] keyword *) - if !backend = FStar then begin - F.pp_print_string fmt "unfold"; - F.pp_print_space fmt (); - end; + F.pp_print_string fmt "unfold"; + F.pp_print_space fmt (); (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) F.pp_open_hvbox fmt ctx.indent_incr; (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) F.pp_open_hovbox fmt ctx.indent_incr; (* > "let FUN_NAME" *) - F.pp_print_string fmt ((if !backend = FStar then "let " else "def ") ^ def_name); + F.pp_print_string fmt ("let " ^ def_name); F.pp_print_space fmt (); (* Extract the parameters *) let space = ref true in @@ -2046,20 +2068,119 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ":"; (* Print the signature *) F.pp_print_space fmt (); - F.pp_print_string fmt (if !backend = FStar then "nat" else "Nat"); + F.pp_print_string fmt "nat"; (* Print the "=" *) F.pp_print_space fmt (); - F.pp_print_string fmt (if !backend = FStar then "=" else ":="); + F.pp_print_string fmt "="; (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) F.pp_close_box fmt (); F.pp_print_space fmt (); (* Print the "admit ()" *) - F.pp_print_string fmt (if !backend = FStar then "admit ()" else "sorry"); + F.pp_print_string fmt "admit ()"; + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) + F.pp_close_box fmt (); + (* Close the box for the whole definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Extract templates for the termination_by and decreases_by clauses of a + recursive function definition. + + For Lean only. + + We extract two commands. The first one is a regular definition for the + termination measure (the value derived from the function arguments that + decreases over function calls). The second one is a macro definition that + defines a proof script (allowed to refer to function arguments) that proves + termination. +*) +let extract_termination_and_decreasing (ctx: extraction_ctx) (fmt: F.formatter) (def: fun_decl): unit = + assert (!backend = Lean); + + (* Retrieve the function name *) + let def_name = ctx_get_terminates_clause def.def_id def.loop_id ctx in + let def_body = Option.get def.body in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]: termination measure"); + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt 0; + (* Add the [unfold] keyword *) + F.pp_print_string fmt "@[simp]"; + F.pp_print_space fmt (); + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) + F.pp_open_hvbox fmt ctx.indent_incr; + (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "let FUN_NAME" *) + F.pp_print_string fmt ("def " ^ def_name); + F.pp_print_space fmt (); + (* Extract the parameters *) + let space = ref true in + let _, ctx_body = extract_fun_parameters space ctx fmt def in + (* Print the ":=" *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":="; + (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + (* Tuple of the arguments *) + let vars = List.map (fun (v: var) -> v.id) def_body.inputs in + if List.length vars = 1 then + F.pp_print_string fmt (ctx_get_var (List.hd vars) ctx_body) + else begin + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "("; + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (fun v -> F.pp_print_string fmt (ctx_get_var v ctx_body)) + vars; + F.pp_print_string fmt ")"; + F.pp_close_box fmt (); + end; (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *) F.pp_close_box fmt (); (* Close the box for the whole definition *) F.pp_close_box fmt (); (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0; + + (* Now extract a template for the termination proof *) + let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in + (* syntax <def_name> term ... term : tactic *) + F.pp_print_break fmt 0 0; + F.pp_open_hvbox fmt 0; + F.pp_print_string fmt "syntax \""; + F.pp_print_string fmt def_name; + F.pp_print_string fmt "\" term+ : tactic"; + F.pp_print_break fmt 0 0; + F.pp_print_break fmt 0 0; + (* macro_rules | `(tactic| fact_termination_proof $x) => `(tactic| ( *) + F.pp_print_string fmt "macro_rules"; + F.pp_print_space fmt (); + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt "| `(tactic| "; + F.pp_print_string fmt def_name; + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (fun v -> + F.pp_print_string fmt "$"; + F.pp_print_string fmt (ctx_get_var v ctx_body)) + vars; + F.pp_print_string fmt ") =>"; + F.pp_close_box fmt (); + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "`(tactic| sorry)"; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + F.pp_close_box fmt (); F.pp_print_break fmt 0 0 (** Extract a function declaration. diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 152dfc99..77170b5b 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -203,6 +203,21 @@ type formatter = { the same purpose as in {!field:fun_name}. - loop identifier, if this is for a loop *) + terminates_clause_name : + A.FunDeclId.id -> fun_name -> int -> LoopId.id option -> string; + (** Generates the name of the measure used to prove/reason about + termination. The generated code uses this clause where needed, + but its body must be defined by the user. Lean only. + + Inputs: + - function id: this is especially useful to identify whether the + function is an assumed function or a local function + - function basename + - the number of loops in the parent function. This is used for + the same purpose as in {!field:fun_name}. + - loop identifier, if this is for a loop + *) + var_basename : StringSet.t -> string option -> ty -> string; (** Generates a variable basename. @@ -285,6 +300,12 @@ type id = the body of those clauses must be defined by the user, in the proper files. *) + | TerminatesClauseId of (A.fun_id * LoopId.id option) + (** The definition which provides the decreases/termination measure. + We insert calls to this clause to prove/reason about termination: + the body of those clauses must be defined by the user, in the + proper files. + *) | TypeId of type_id | StructId of type_id (** We use this when we manipulate the names of the structure @@ -486,6 +507,19 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | Some lid -> ", loop: " ^ LoopId.to_string lid in "decreases clause for function: " ^ fun_name ^ loop + | TerminatesClauseId (fid, lid) -> + let fun_name = + match fid with + | Regular fid -> + Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name + | Assumed aid -> A.show_assumed_fun_id aid + in + let loop = + match lid with + | None -> "" + | Some lid -> ", loop: " ^ LoopId.to_string lid + in + "terminates clause for function: " ^ fun_name ^ loop | TypeId id -> "type name: " ^ get_type_name id | StructId id -> "struct constructor of: " ^ get_type_name id | VariantId (id, variant_id) -> @@ -596,6 +630,10 @@ let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = ctx_get (DecreasesClauseId (Regular def_id, loop_id)) ctx +let ctx_get_terminates_clause (def_id : A.FunDeclId.id) + (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = + ctx_get (TerminatesClauseId (Regular def_id, loop_id)) ctx + (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = @@ -688,6 +726,14 @@ let ctx_add_decreases_clause (def : fun_decl) (ctx : extraction_ctx) : in ctx_add (DecreasesClauseId (Regular def.def_id, def.loop_id)) name ctx +let ctx_add_terminates_clause (def : fun_decl) (ctx : extraction_ctx) : + extraction_ctx = + let name = + ctx.fmt.terminates_clause_name def.def_id def.basename def.num_loops + def.loop_id + in + ctx_add (TerminatesClauseId (Regular def.def_id, def.loop_id)) name ctx + let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = let name = ctx.fmt.global_name def.name in diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 4ca9eff2..0a1c8f9a 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -563,7 +563,10 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let extract_decrease decl = let has_decr_clause = has_decreases_clause decl in if has_decr_clause then - Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl + if !Config.backend = Lean then + Extract.extract_termination_and_decreasing ctx.extract_ctx fmt decl + else + Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl in extract_decrease fwd; List.iter extract_decrease loop_fwds) |