summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-02-09 16:57:15 +0100
committerSon Ho2022-02-09 16:57:15 +0100
commit3cd24d0b0ecd4a7a71587a5f1479852f40f959ff (patch)
tree9605cea75b0153e04526e79c2604ebb0026b9298
parentdd1a786022b493c10da6f4d6d1c88a41b70e1eb5 (diff)
Implement generation of template decrease clauses
-rw-r--r--src/ExtractToFStar.ml104
-rw-r--r--src/Translate.ml9
2 files changed, 108 insertions, 5 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 95a29c56..a481affd 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -940,6 +940,99 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_close_box fmt ())
| Meta (_, e) -> extract_texpression ctx fmt inside e
+(** A small utility to print the parameters of a function signature.
+
+ We return two contexts:
+ - the context augmented with bindings for the type parameters
+ - the previous context augmented with bindings for the input values
+ *)
+let extract_fun_parameters (ctx : extraction_ctx) (fmt : F.formatter)
+ (def : fun_def) : extraction_ctx * extraction_ctx =
+ (* Add the type parameters - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, _ = ctx_add_type_params def.signature.type_params ctx in
+ (* Print the parameters - rk.: we should have filtered the functions
+ * with no input parameters *)
+ (* The type parameters *)
+ if def.signature.type_params <> [] then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ List.iter
+ (fun (p : type_var) ->
+ let pname = ctx_get_type_var p.index ctx in
+ F.pp_print_string fmt pname;
+ F.pp_print_space fmt ())
+ def.signature.type_params;
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "Type0)");
+ (* The input parameters - note that doing this adds bindings to the context *)
+ let ctx_body =
+ List.fold_left
+ (fun ctx (lv : typed_lvalue) ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ let ctx = extract_typed_lvalue ctx fmt false lv in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt false lv.ty;
+ F.pp_print_string fmt ")";
+ ctx)
+ ctx def.inputs_lvs
+ in
+ (ctx, ctx_body)
+
+(** Extract a decrease clause function template body.
+
+ In order to help the user, we can generate a template for the functions
+ required by the decreases clauses. We simply generate definitions of
+ the following form in a separate file:
+ ```
+ let f_decrease (t : Type0) (x : t) : nat = admit()
+ ```
+
+ Where the translated functions for `f` look like this:
+ ```
+ let f_fwd (t : Type0) (x : t) : Tot ... (decreases (f_decrease t x)) = ...
+ ```
+ *)
+let extract_template_decrease_clause (ctx : extraction_ctx) (fmt : F.formatter)
+ (def : fun_def) : unit =
+ (* Retrieve the function name *)
+ let def_name = ctx_get_decrease_clause def.def_id ctx in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ F.pp_print_string fmt
+ ("(** [" ^ Print.name_to_string def.basename ^ "]: decrease clause *)");
+ F.pp_print_space fmt ();
+ (* Open a box for the definition, so that whenever possible it gets printed on
+ * one line - TODO: remove? *)
+ F.pp_open_hvbox fmt 0;
+ (* Open a box for "let FUN_NAME (PARAMS) =" *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* > "let FUN_NAME" *)
+ F.pp_print_string fmt ("let " ^ def_name);
+ (* Extract the parameters *)
+ let _, _ = extract_fun_parameters ctx fmt def in
+ (* Print the signature *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "nat";
+ (* Print the body *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "=";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "admit ()";
+ (* Close the box for "let FUN_NAME (PARAMS) =" *)
+ F.pp_close_box fmt ();
+ (* Close the box for the whole definition *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
(** Extract a function definition.
Note that all the names used for extraction should already have been
@@ -954,9 +1047,9 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_def) : unit =
(* Retrieve the function name *)
let def_name = ctx_get_local_function def.def_id def.back_id ctx in
- (* Add the type parameters - note that we need those bindings only for the
- * body translation (they are not top-level) *)
- let ctx, _ = ctx_add_type_params def.signature.type_params ctx in
+ (* (* Add the type parameters - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, _ = ctx_add_type_params def.signature.type_params ctx in *)
(* Add a break before *)
F.pp_print_break fmt 0 0;
(* Print a comment to link the extracted type to its original rust definition *)
@@ -972,7 +1065,8 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
match qualif with Let -> "let" | LetRec -> "let rec" | And -> "and"
in
F.pp_print_string fmt (qualif ^ " " ^ def_name);
- (* Print the parameters - rk.: we should have filtered the functions
+ let ctx, ctx_body = extract_fun_parameters ctx fmt def in
+ (*(* Print the parameters - rk.: we should have filtered the functions
* with no input parameters *)
(* The type parameters *)
if def.signature.type_params <> [] then (
@@ -1001,7 +1095,7 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ")";
ctx)
ctx def.inputs_lvs
- in
+ in*)
(* Print the return type - note that we have to be careful when
* printing the input values for the decrease clause, because
* it introduces bindings in the context... We thus "forget"
diff --git a/src/Translate.ml b/src/Translate.ml
index 3c12bc90..ec817b50 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -347,6 +347,15 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
if keep_fwd then (fwd, fwd) :: back_ls else back_ls)
pure_ls)
in
+ (* Extract the decrease clauses template bodies *)
+ if config.extract_template_decrease_clauses then
+ List.iter
+ (fun (_, (fwd, _)) ->
+ let has_decr_clause = has_decrease_clause fwd in
+ if has_decr_clause then
+ ExtractToFStar.extract_template_decrease_clause ctx.extract_ctx fmt
+ fwd)
+ pure_ls;
(* Extract the function definitions *)
(if config.extract_fun_defs then
(* Check if the functions are mutually recursive - this really works