summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-09 16:35:36 +0100
committerSon Ho2022-02-09 16:35:36 +0100
commitdd1a786022b493c10da6f4d6d1c88a41b70e1eb5 (patch)
tree959895c7e52f45635fff8999f9d8dcb0de0782c1 /src
parent056e6af4cf469dc9d72dff5222363edd9b563588 (diff)
Implement the generation of `decreases` clauses in the definition
signatures
Diffstat (limited to '')
-rw-r--r--src/ExtractToFStar.ml68
-rw-r--r--src/Logging.ml7
-rw-r--r--src/PureMicroPasses.ml8
-rw-r--r--src/PureToExtract.ml4
-rw-r--r--src/Translate.ml26
5 files changed, 93 insertions, 20 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 9ad3c94c..95a29c56 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -944,9 +944,14 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
Note that all the names used for extraction should already have been
registered.
+
+ We take the definition of the forward translation as parameter (which is
+ equal to the definition to extract, if we extract a forward function) because
+ it is useful for the decrease clause.
*)
let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
- (qualif : fun_def_qualif) (def : fun_def) : unit =
+ (qualif : fun_def_qualif) (has_decrease_clause : bool) (fwd_def : fun_def)
+ (def : fun_def) : unit =
(* Retrieve the function name *)
let def_name = ctx_get_local_function def.def_id def.back_id ctx in
(* Add the type parameters - note that we need those bindings only for the
@@ -982,8 +987,8 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
F.pp_print_string fmt "Type0)");
- (* The input parameters - note that doing this adds bindings in the context *)
- let ctx =
+ (* The input parameters - note that doing this adds bindings to the context *)
+ let ctx_body =
List.fold_left
(fun ctx (lv : typed_lvalue) ->
F.pp_print_space fmt ();
@@ -997,7 +1002,11 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
ctx)
ctx def.inputs_lvs
in
- (* Print the return type *)
+ (* Print the return type - note that we have to be careful when
+ * printing the input values for the decrease clause, because
+ * it introduces bindings in the context... We thus "forget"
+ * the bindings we introduced above.
+ * TODO: figure out a cleaner way *)
let _ =
F.pp_print_space fmt ();
(* Open a box for the return type *)
@@ -1005,21 +1014,64 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
(* Print the return type *)
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
- extract_ty ctx fmt false
+ (* `Tot` *)
+ if has_decrease_clause then (
+ F.pp_print_string fmt "Tot";
+ F.pp_print_space fmt ());
+ extract_ty ctx fmt has_decrease_clause
(Collections.List.to_cons_nil def.signature.outputs);
(* Close the box for the return type *)
- F.pp_close_box fmt ()
+ F.pp_close_box fmt ();
+ (* Print the decrease clause *)
+ if has_decrease_clause then (
+ F.pp_print_space fmt ();
+ (* Open a box for the decrease clause *)
+ F.pp_open_hovbox fmt 0;
+ (* *)
+ F.pp_print_string fmt "(decreases";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ (* The name of the decrease clause *)
+ let decr_name = ctx_get_decrease_clause def.def_id ctx in
+ F.pp_print_string fmt decr_name;
+ (* Print the type parameters *)
+ List.iter
+ (fun (p : type_var) ->
+ let pname = ctx_get_type_var p.index ctx in
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt pname)
+ def.signature.type_params;
+ (* Print the input values: we have to be careful here to print
+ * only the input values which are in common with the *forward*
+ * function (the additional input values "given back" to the
+ * backward functions have no influence on termination: we thus
+ * share the decrease clauses between the forward and the backward
+ * functions) *)
+ let inputs_lvs =
+ Collections.List.prefix (List.length fwd_def.inputs_lvs) def.inputs_lvs
+ in
+ let _ =
+ List.fold_left
+ (fun ctx (lv : typed_lvalue) ->
+ F.pp_print_space fmt ();
+ let ctx = extract_typed_lvalue ctx fmt false lv in
+ ctx)
+ ctx inputs_lvs
+ in
+ F.pp_print_string fmt "))";
+ (* Close the box for the decrease clause *)
+ F.pp_close_box fmt ())
in
(* Print the "=" *)
F.pp_print_space fmt ();
F.pp_print_string fmt "=";
(* Close the box for "let FUN_NAME (PARAMS) =" *)
F.pp_close_box fmt ();
- F.pp_print_break fmt 1 ctx.indent_incr;
+ F.pp_print_break fmt 1 ctx_body.indent_incr;
(* Open a box for the body *)
F.pp_open_hvbox fmt 0;
(* Extract the body *)
- let _ = extract_texpression ctx fmt false def.body in
+ let _ = extract_texpression ctx_body fmt false def.body in
(* Close the box for the body *)
F.pp_close_box fmt ();
(* Close the box for the definition *)
diff --git a/src/Logging.ml b/src/Logging.ml
index 55b07ed4..a7313de5 100644
--- a/src/Logging.ml
+++ b/src/Logging.ml
@@ -122,15 +122,16 @@ let style_to_codes s =
| Fg c -> (to_fg_code c, to_fg_code Default)
| Bg c -> (to_bg_code c, to_bg_code Default)
-(** TODO: comes from easy_logging (did not manage to reuse the module directly) *)
+(** TODO: comes from easy_logging (did not manage to reuse the module directly)
+ I made a minor modifications, though. *)
let level_to_color (lvl : L.level) =
match lvl with
| L.Flash -> LMagenta
| Error -> LRed
| Warning -> LYellow
- | Info -> LBlue
+ | Info -> LGreen
| Trace -> Cyan
- | Debug -> Green
+ | Debug -> LBlue
| NoLevel -> Default
(** [format styles str] formats [str] to the given [styles] -
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index dba6b5e8..b4f16e6e 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -617,10 +617,14 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
in
(* Visit the body *)
let body, used_vars = expr_visitor#visit_texpression () def.body in
- (* Visit the parameters *)
+ (* Visit the parameters - TODO: update: we need filter only if the definition
+ * is not recursive (otherwise it might mess up with the decrease clauses).
+ * For now we deactivate the filtering *)
let used_vars = used_vars () in
let inputs_lvs =
- List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs
+ if false then
+ List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs
+ else def.inputs_lvs
in
(* Return *)
{ def with body; inputs_lvs }
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index bb010a05..4d29d517 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -446,6 +446,10 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id)
(ctx : extraction_ctx) : string =
ctx_get (VariantId (def_id, variant_id)) ctx
+let ctx_get_decrease_clause (def_id : FunDefId.id) (ctx : extraction_ctx) :
+ string =
+ ctx_get (DecreaseClauseId (A.Local def_id)) ctx
+
(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (basename : string) (id : TypeVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
diff --git a/src/Translate.ml b/src/Translate.ml
index 9d5d2b75..3c12bc90 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -326,18 +326,25 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
ExtractToFStar.extract_type_def ctx.extract_ctx fmt qualif def
in
+ (* Utility to check a function has a decrease clause *)
+ let has_decrease_clause (def : Pure.fun_def) : bool =
+ Pure.FunDefId.Set.mem def.def_id ctx.functions_with_decrease_clause
+ in
+
(* In case of (non-mutually) recursive functions, we use a simple procedure to
* check if the forward and backward functions are mutually recursive.
*)
let export_functions (is_rec : bool)
(pure_ls : (bool * pure_fun_translation) list) : unit =
(* Concatenate the function definitions, filtering the useless forward
- * functions. *)
+ * functions. We also make pairs: (forward function, backward function)
+ * (the forward function contains useful information that we want to keep) *)
let fls =
List.concat
(List.map
(fun (keep_fwd, (fwd, back_ls)) ->
- if keep_fwd then fwd :: back_ls else back_ls)
+ let back_ls = List.map (fun back -> (fwd, back)) back_ls in
+ if keep_fwd then (fwd, fwd) :: back_ls else back_ls)
pure_ls)
in
(* Extract the function definitions *)
@@ -348,19 +355,23 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let is_mut_rec =
if is_rec then
if List.length pure_ls <= 1 then
- not (PureUtils.functions_not_mutually_recursive fls)
+ not (PureUtils.functions_not_mutually_recursive (List.map fst fls))
else true
else false
in
List.iteri
- (fun i def ->
+ (fun i (fwd_def, def) ->
let qualif =
if not is_rec then ExtractToFStar.Let
else if is_mut_rec then
if i = 0 then ExtractToFStar.LetRec else ExtractToFStar.And
else ExtractToFStar.LetRec
in
- ExtractToFStar.extract_fun_def ctx.extract_ctx fmt qualif def)
+ let has_decr_clause =
+ has_decrease_clause def && config.extract_decrease_clauses
+ in
+ ExtractToFStar.extract_fun_def ctx.extract_ctx fmt qualif
+ has_decr_clause fwd_def def)
fls);
(* Insert unit tests if necessary *)
if config.test_unit_functions then
@@ -520,8 +531,9 @@ let translate_module (filename : string) (dest_dir : string) (config : config)
let gen_config =
{
extract_types = true;
- extract_decrease_clauses = true;
- extract_template_decrease_clauses = false;
+ extract_decrease_clauses = config.extract_decrease_clauses;
+ extract_template_decrease_clauses =
+ config.extract_template_decrease_clauses;
extract_fun_defs = true;
test_unit_functions = config.test_unit_functions;
}