summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorJonathan Protzenko2023-01-30 18:05:21 -0800
committerSon HO2023-06-04 21:44:33 +0200
commit9804a5f28cedc79ac89d3b97ec6addb42752df3d (patch)
tree3549c94a08498578f3cfd145475891f45d4ba422 /compiler
parent1d6742c059cf53e73c9bc66cec7ac1f857830e78 (diff)
Fix some printing bits, proper syntax for terminates and decreases clauses
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml151
-rw-r--r--compiler/ExtractBase.ml46
-rw-r--r--compiler/Translate.ml5
3 files changed, 186 insertions, 16 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 7670c753..f45b9b58 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -489,6 +489,16 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
fname ^ lp_suffix ^ suffix
in
+ let terminates_clause_name (_fid : A.FunDeclId.id) (fname : fun_name)
+ (num_loops : int) (loop_id : LoopId.id option) : string =
+ let fname = fun_name_to_snake_case fname in
+ let lp_suffix = default_fun_loop_suffix num_loops loop_id in
+ (* Compute the suffix *)
+ let suffix = "_terminates" in
+ (* Concatenate *)
+ fname ^ lp_suffix ^ suffix
+ in
+
let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
: string =
(* If there is a basename, we use it *)
@@ -619,6 +629,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
global_name;
fun_name;
decreases_clause_name;
+ terminates_clause_name;
var_basename;
type_var_basename;
append_index;
@@ -1348,7 +1359,11 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
let (fwd, loop_fwds), back_ls = def in
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
- if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx
+ if has_decreases_clause def then
+ let ctx = ctx_add_decreases_clause def ctx in
+ ctx_add_terminates_clause def ctx
+ else
+ ctx
in
let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
(* Register the function names *)
@@ -1626,31 +1641,40 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
in
let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in
if is_lean_struct then
- (* TODO: enclosing curly brace is indented too far to the right *)
(* TODO: when only one or two fields differ, considering using the with
syntax (peephole optimization) *)
let decl_id = match adt_cons.adt_id with | AdtId id -> id | _ -> assert false in
let def_kind = (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls).kind in
let fields = match def_kind with | T.Struct fields -> fields | _ -> assert false in
let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
- F.pp_open_vbox fmt ctx.indent_incr;
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr;
F.pp_print_string fmt "{";
F.pp_print_space fmt ();
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ F.pp_open_hvbox fmt 0;
Collections.List.iter_link
(fun () ->
F.pp_print_string fmt ",";
+ F.pp_close_box fmt ();
F.pp_print_space fmt ()
)
(fun ((fid, _), e) ->
+ F.pp_open_hovbox fmt 0;
let f = ctx_get_field adt_cons.adt_id fid ctx in
F.pp_print_string fmt f;
F.pp_print_string fmt " := ";
- extract_texpression ctx fmt true e
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ extract_texpression ctx fmt true e;
+ F.pp_close_box fmt ()
)
(List.combine fields args);
- F.pp_print_space fmt ();
F.pp_close_box fmt ();
- F.pp_print_string fmt "}";
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}"
else
let use_parentheses = inside && args <> [] in
if use_parentheses then F.pp_print_string fmt "(";
@@ -2015,7 +2039,7 @@ let assert_backend_supports_decreases_clauses () =
*)
let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
- assert_backend_supports_decreases_clauses ();
+ assert (!backend = FStar);
(* Retrieve the function name *)
let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
@@ -2028,16 +2052,14 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
* one line *)
F.pp_open_hvbox fmt 0;
(* Add the [unfold] keyword *)
- if !backend = FStar then begin
- F.pp_print_string fmt "unfold";
- F.pp_print_space fmt ();
- end;
+ F.pp_print_string fmt "unfold";
+ F.pp_print_space fmt ();
(* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
F.pp_open_hvbox fmt ctx.indent_incr;
(* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* > "let FUN_NAME" *)
- F.pp_print_string fmt ((if !backend = FStar then "let " else "def ") ^ def_name);
+ F.pp_print_string fmt ("let " ^ def_name);
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
@@ -2046,20 +2068,119 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ":";
(* Print the signature *)
F.pp_print_space fmt ();
- F.pp_print_string fmt (if !backend = FStar then "nat" else "Nat");
+ F.pp_print_string fmt "nat";
(* Print the "=" *)
F.pp_print_space fmt ();
- F.pp_print_string fmt (if !backend = FStar then "=" else ":=");
+ F.pp_print_string fmt "=";
(* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *)
F.pp_close_box fmt ();
F.pp_print_space fmt ();
(* Print the "admit ()" *)
- F.pp_print_string fmt (if !backend = FStar then "admit ()" else "sorry");
+ F.pp_print_string fmt "admit ()";
+ (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
+ F.pp_close_box fmt ();
+ (* Close the box for the whole definition *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
+(** Extract templates for the termination_by and decreases_by clauses of a
+ recursive function definition.
+
+ For Lean only.
+
+ We extract two commands. The first one is a regular definition for the
+ termination measure (the value derived from the function arguments that
+ decreases over function calls). The second one is a macro definition that
+ defines a proof script (allowed to refer to function arguments) that proves
+ termination.
+*)
+let extract_termination_and_decreasing (ctx: extraction_ctx) (fmt: F.formatter) (def: fun_decl): unit =
+ assert (!backend = Lean);
+
+ (* Retrieve the function name *)
+ let def_name = ctx_get_terminates_clause def.def_id def.loop_id ctx in
+ let def_body = Option.get def.body in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]: termination measure");
+ F.pp_print_space fmt ();
+ (* Open a box for the definition, so that whenever possible it gets printed on
+ * one line *)
+ F.pp_open_hvbox fmt 0;
+ (* Add the [unfold] keyword *)
+ F.pp_print_string fmt "@[simp]";
+ F.pp_print_space fmt ();
+ (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* > "let FUN_NAME" *)
+ F.pp_print_string fmt ("def " ^ def_name);
+ F.pp_print_space fmt ();
+ (* Extract the parameters *)
+ let space = ref true in
+ let _, ctx_body = extract_fun_parameters space ctx fmt def in
+ (* Print the ":=" *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":=";
+ (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *)
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ (* Tuple of the arguments *)
+ let vars = List.map (fun (v: var) -> v.id) def_body.inputs in
+ if List.length vars = 1 then
+ F.pp_print_string fmt (ctx_get_var (List.hd vars) ctx_body)
+ else begin
+ F.pp_open_hovbox fmt 0;
+ F.pp_print_string fmt "(";
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun v -> F.pp_print_string fmt (ctx_get_var v ctx_body))
+ vars;
+ F.pp_print_string fmt ")";
+ F.pp_close_box fmt ();
+ end;
(* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
F.pp_close_box fmt ();
(* Close the box for the whole definition *)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0;
+
+ (* Now extract a template for the termination proof *)
+ let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
+ (* syntax <def_name> term ... term : tactic *)
+ F.pp_print_break fmt 0 0;
+ F.pp_open_hvbox fmt 0;
+ F.pp_print_string fmt "syntax \"";
+ F.pp_print_string fmt def_name;
+ F.pp_print_string fmt "\" term+ : tactic";
+ F.pp_print_break fmt 0 0;
+ F.pp_print_break fmt 0 0;
+ (* macro_rules | `(tactic| fact_termination_proof $x) => `(tactic| ( *)
+ F.pp_print_string fmt "macro_rules";
+ F.pp_print_space fmt ();
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_open_hovbox fmt 0;
+ F.pp_print_string fmt "| `(tactic| ";
+ F.pp_print_string fmt def_name;
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (fun v ->
+ F.pp_print_string fmt "$";
+ F.pp_print_string fmt (ctx_get_var v ctx_body))
+ vars;
+ F.pp_print_string fmt ") =>";
+ F.pp_close_box fmt ();
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt "`(tactic| sorry)";
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
F.pp_print_break fmt 0 0
(** Extract a function declaration.
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 152dfc99..77170b5b 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -203,6 +203,21 @@ type formatter = {
the same purpose as in {!field:fun_name}.
- loop identifier, if this is for a loop
*)
+ terminates_clause_name :
+ A.FunDeclId.id -> fun_name -> int -> LoopId.id option -> string;
+ (** Generates the name of the measure used to prove/reason about
+ termination. The generated code uses this clause where needed,
+ but its body must be defined by the user. Lean only.
+
+ Inputs:
+ - function id: this is especially useful to identify whether the
+ function is an assumed function or a local function
+ - function basename
+ - the number of loops in the parent function. This is used for
+ the same purpose as in {!field:fun_name}.
+ - loop identifier, if this is for a loop
+ *)
+
var_basename : StringSet.t -> string option -> ty -> string;
(** Generates a variable basename.
@@ -285,6 +300,12 @@ type id =
the body of those clauses must be defined by the user, in the
proper files.
*)
+ | TerminatesClauseId of (A.fun_id * LoopId.id option)
+ (** The definition which provides the decreases/termination measure.
+ We insert calls to this clause to prove/reason about termination:
+ the body of those clauses must be defined by the user, in the
+ proper files.
+ *)
| TypeId of type_id
| StructId of type_id
(** We use this when we manipulate the names of the structure
@@ -486,6 +507,19 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| Some lid -> ", loop: " ^ LoopId.to_string lid
in
"decreases clause for function: " ^ fun_name ^ loop
+ | TerminatesClauseId (fid, lid) ->
+ let fun_name =
+ match fid with
+ | Regular fid ->
+ Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name
+ | Assumed aid -> A.show_assumed_fun_id aid
+ in
+ let loop =
+ match lid with
+ | None -> ""
+ | Some lid -> ", loop: " ^ LoopId.to_string lid
+ in
+ "terminates clause for function: " ^ fun_name ^ loop
| TypeId id -> "type name: " ^ get_type_name id
| StructId id -> "struct constructor of: " ^ get_type_name id
| VariantId (id, variant_id) ->
@@ -596,6 +630,10 @@ let ctx_get_decreases_clause (def_id : A.FunDeclId.id)
(loop_id : LoopId.id option) (ctx : extraction_ctx) : string =
ctx_get (DecreasesClauseId (Regular def_id, loop_id)) ctx
+let ctx_get_terminates_clause (def_id : A.FunDeclId.id)
+ (loop_id : LoopId.id option) (ctx : extraction_ctx) : string =
+ ctx_get (TerminatesClauseId (Regular def_id, loop_id)) ctx
+
(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (basename : string) (id : TypeVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
@@ -688,6 +726,14 @@ let ctx_add_decreases_clause (def : fun_decl) (ctx : extraction_ctx) :
in
ctx_add (DecreasesClauseId (Regular def.def_id, def.loop_id)) name ctx
+let ctx_add_terminates_clause (def : fun_decl) (ctx : extraction_ctx) :
+ extraction_ctx =
+ let name =
+ ctx.fmt.terminates_clause_name def.def_id def.basename def.num_loops
+ def.loop_id
+ in
+ ctx_add (TerminatesClauseId (Regular def.def_id, def.loop_id)) name ctx
+
let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
extraction_ctx =
let name = ctx.fmt.global_name def.name in
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 4ca9eff2..0a1c8f9a 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -563,7 +563,10 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let extract_decrease decl =
let has_decr_clause = has_decreases_clause decl in
if has_decr_clause then
- Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl
+ if !Config.backend = Lean then
+ Extract.extract_termination_and_decreasing ctx.extract_ctx fmt decl
+ else
+ Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl
in
extract_decrease fwd;
List.iter extract_decrease loop_fwds)