summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml151
1 files changed, 136 insertions, 15 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 7670c753..f45b9b58 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -489,6 +489,16 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
fname ^ lp_suffix ^ suffix
in
+ let terminates_clause_name (_fid : A.FunDeclId.id) (fname : fun_name)
+ (num_loops : int) (loop_id : LoopId.id option) : string =
+ let fname = fun_name_to_snake_case fname in
+ let lp_suffix = default_fun_loop_suffix num_loops loop_id in
+ (* Compute the suffix *)
+ let suffix = "_terminates" in
+ (* Concatenate *)
+ fname ^ lp_suffix ^ suffix
+ in
+
let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
: string =
(* If there is a basename, we use it *)
@@ -619,6 +629,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
global_name;
fun_name;
decreases_clause_name;
+ terminates_clause_name;
var_basename;
type_var_basename;
append_index;
@@ -1348,7 +1359,11 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
let (fwd, loop_fwds), back_ls = def in
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
- if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx
+ if has_decreases_clause def then
+ let ctx = ctx_add_decreases_clause def ctx in
+ ctx_add_terminates_clause def ctx
+ else
+ ctx
in
let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
(* Register the function names *)
@@ -1626,31 +1641,40 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
in
let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in
if is_lean_struct then
- (* TODO: enclosing curly brace is indented too far to the right *)
(* TODO: when only one or two fields differ, considering using the with
syntax (peephole optimization) *)
let decl_id = match adt_cons.adt_id with | AdtId id -> id | _ -> assert false in
let def_kind = (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls).kind in
let fields = match def_kind with | T.Struct fields -> fields | _ -> assert false in
let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
- F.pp_open_vbox fmt ctx.indent_incr;
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr;
F.pp_print_string fmt "{";
F.pp_print_space fmt ();
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ F.pp_open_hvbox fmt 0;
Collections.List.iter_link
(fun () ->
F.pp_print_string fmt ",";
+ F.pp_close_box fmt ();
F.pp_print_space fmt ()
)
(fun ((fid, _), e) ->
+ F.pp_open_hovbox fmt 0;
let f = ctx_get_field adt_cons.adt_id fid ctx in
F.pp_print_string fmt f;
F.pp_print_string fmt " := ";
- extract_texpression ctx fmt true e
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ extract_texpression ctx fmt true e;
+ F.pp_close_box fmt ()
)
(List.combine fields args);
- F.pp_print_space fmt ();
F.pp_close_box fmt ();
- F.pp_print_string fmt "}";
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}"
else
let use_parentheses = inside && args <> [] in
if use_parentheses then F.pp_print_string fmt "(";
@@ -2015,7 +2039,7 @@ let assert_backend_supports_decreases_clauses () =
*)
let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
- assert_backend_supports_decreases_clauses ();
+ assert (!backend = FStar);
(* Retrieve the function name *)
let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
@@ -2028,16 +2052,14 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
* one line *)
F.pp_open_hvbox fmt 0;
(* Add the [unfold] keyword *)
- if !backend = FStar then begin
- F.pp_print_string fmt "unfold";
- F.pp_print_space fmt ();
- end;
+ F.pp_print_string fmt "unfold";
+ F.pp_print_space fmt ();
(* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
F.pp_open_hvbox fmt ctx.indent_incr;
(* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* > "let FUN_NAME" *)
- F.pp_print_string fmt ((if !backend = FStar then "let " else "def ") ^ def_name);
+ F.pp_print_string fmt ("let " ^ def_name);
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
@@ -2046,20 +2068,119 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ":";
(* Print the signature *)
F.pp_print_space fmt ();
- F.pp_print_string fmt (if !backend = FStar then "nat" else "Nat");
+ F.pp_print_string fmt "nat";
(* Print the "=" *)
F.pp_print_space fmt ();
- F.pp_print_string fmt (if !backend = FStar then "=" else ":=");
+ F.pp_print_string fmt "=";
(* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *)
F.pp_close_box fmt ();
F.pp_print_space fmt ();
(* Print the "admit ()" *)
- F.pp_print_string fmt (if !backend = FStar then "admit ()" else "sorry");
+ F.pp_print_string fmt "admit ()";
+ (* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
+ F.pp_close_box fmt ();
+ (* Close the box for the whole definition *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
+(** Extract templates for the termination_by and decreases_by clauses of a
+ recursive function definition.
+
+ For Lean only.
+
+ We extract two commands. The first one is a regular definition for the
+ termination measure (the value derived from the function arguments that
+ decreases over function calls). The second one is a macro definition that
+ defines a proof script (allowed to refer to function arguments) that proves
+ termination.
+*)
+let extract_termination_and_decreasing (ctx: extraction_ctx) (fmt: F.formatter) (def: fun_decl): unit =
+ assert (!backend = Lean);
+
+ (* Retrieve the function name *)
+ let def_name = ctx_get_terminates_clause def.def_id def.loop_id ctx in
+ let def_body = Option.get def.body in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt ("[" ^ Print.fun_name_to_string def.basename ^ "]: termination measure");
+ F.pp_print_space fmt ();
+ (* Open a box for the definition, so that whenever possible it gets printed on
+ * one line *)
+ F.pp_open_hvbox fmt 0;
+ (* Add the [unfold] keyword *)
+ F.pp_print_string fmt "@[simp]";
+ F.pp_print_space fmt ();
+ (* Open a box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ (* Open a box for "let FUN_NAME (PARAMS) : EFFECT =" *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* > "let FUN_NAME" *)
+ F.pp_print_string fmt ("def " ^ def_name);
+ F.pp_print_space fmt ();
+ (* Extract the parameters *)
+ let space = ref true in
+ let _, ctx_body = extract_fun_parameters space ctx fmt def in
+ (* Print the ":=" *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":=";
+ (* Close the box for "let FUN_NAME (PARAMS) : EFFECT =" *)
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ (* Tuple of the arguments *)
+ let vars = List.map (fun (v: var) -> v.id) def_body.inputs in
+ if List.length vars = 1 then
+ F.pp_print_string fmt (ctx_get_var (List.hd vars) ctx_body)
+ else begin
+ F.pp_open_hovbox fmt 0;
+ F.pp_print_string fmt "(";
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (fun v -> F.pp_print_string fmt (ctx_get_var v ctx_body))
+ vars;
+ F.pp_print_string fmt ")";
+ F.pp_close_box fmt ();
+ end;
(* Close the box for "let FUN_NAME (PARAMS) : EFFECT = admit()" *)
F.pp_close_box fmt ();
(* Close the box for the whole definition *)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0;
+
+ (* Now extract a template for the termination proof *)
+ let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
+ (* syntax <def_name> term ... term : tactic *)
+ F.pp_print_break fmt 0 0;
+ F.pp_open_hvbox fmt 0;
+ F.pp_print_string fmt "syntax \"";
+ F.pp_print_string fmt def_name;
+ F.pp_print_string fmt "\" term+ : tactic";
+ F.pp_print_break fmt 0 0;
+ F.pp_print_break fmt 0 0;
+ (* macro_rules | `(tactic| fact_termination_proof $x) => `(tactic| ( *)
+ F.pp_print_string fmt "macro_rules";
+ F.pp_print_space fmt ();
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_open_hovbox fmt 0;
+ F.pp_print_string fmt "| `(tactic| ";
+ F.pp_print_string fmt def_name;
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (fun v ->
+ F.pp_print_string fmt "$";
+ F.pp_print_string fmt (ctx_get_var v ctx_body))
+ vars;
+ F.pp_print_string fmt ") =>";
+ F.pp_close_box fmt ();
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt "`(tactic| sorry)";
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
F.pp_print_break fmt 0 0
(** Extract a function declaration.