diff options
author | Son Ho | 2022-02-09 16:35:36 +0100 |
---|---|---|
committer | Son Ho | 2022-02-09 16:35:36 +0100 |
commit | dd1a786022b493c10da6f4d6d1c88a41b70e1eb5 (patch) | |
tree | 959895c7e52f45635fff8999f9d8dcb0de0782c1 /src | |
parent | 056e6af4cf469dc9d72dff5222363edd9b563588 (diff) |
Implement the generation of `decreases` clauses in the definition
signatures
Diffstat (limited to 'src')
-rw-r--r-- | src/ExtractToFStar.ml | 68 | ||||
-rw-r--r-- | src/Logging.ml | 7 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 8 | ||||
-rw-r--r-- | src/PureToExtract.ml | 4 | ||||
-rw-r--r-- | src/Translate.ml | 26 |
5 files changed, 93 insertions, 20 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 9ad3c94c..95a29c56 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -944,9 +944,14 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) Note that all the names used for extraction should already have been registered. + + We take the definition of the forward translation as parameter (which is + equal to the definition to extract, if we extract a forward function) because + it is useful for the decrease clause. *) let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) - (qualif : fun_def_qualif) (def : fun_def) : unit = + (qualif : fun_def_qualif) (has_decrease_clause : bool) (fwd_def : fun_def) + (def : fun_def) : unit = (* Retrieve the function name *) let def_name = ctx_get_local_function def.def_id def.back_id ctx in (* Add the type parameters - note that we need those bindings only for the @@ -982,8 +987,8 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "Type0)"); - (* The input parameters - note that doing this adds bindings in the context *) - let ctx = + (* The input parameters - note that doing this adds bindings to the context *) + let ctx_body = List.fold_left (fun ctx (lv : typed_lvalue) -> F.pp_print_space fmt (); @@ -997,7 +1002,11 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) ctx) ctx def.inputs_lvs in - (* Print the return type *) + (* Print the return type - note that we have to be careful when + * printing the input values for the decrease clause, because + * it introduces bindings in the context... We thus "forget" + * the bindings we introduced above. + * TODO: figure out a cleaner way *) let _ = F.pp_print_space fmt (); (* Open a box for the return type *) @@ -1005,21 +1014,64 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) (* Print the return type *) F.pp_print_string fmt ":"; F.pp_print_space fmt (); - extract_ty ctx fmt false + (* `Tot` *) + if has_decrease_clause then ( + F.pp_print_string fmt "Tot"; + F.pp_print_space fmt ()); + extract_ty ctx fmt has_decrease_clause (Collections.List.to_cons_nil def.signature.outputs); (* Close the box for the return type *) - F.pp_close_box fmt () + F.pp_close_box fmt (); + (* Print the decrease clause *) + if has_decrease_clause then ( + F.pp_print_space fmt (); + (* Open a box for the decrease clause *) + F.pp_open_hovbox fmt 0; + (* *) + F.pp_print_string fmt "(decreases"; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + (* The name of the decrease clause *) + let decr_name = ctx_get_decrease_clause def.def_id ctx in + F.pp_print_string fmt decr_name; + (* Print the type parameters *) + List.iter + (fun (p : type_var) -> + let pname = ctx_get_type_var p.index ctx in + F.pp_print_space fmt (); + F.pp_print_string fmt pname) + def.signature.type_params; + (* Print the input values: we have to be careful here to print + * only the input values which are in common with the *forward* + * function (the additional input values "given back" to the + * backward functions have no influence on termination: we thus + * share the decrease clauses between the forward and the backward + * functions) *) + let inputs_lvs = + Collections.List.prefix (List.length fwd_def.inputs_lvs) def.inputs_lvs + in + let _ = + List.fold_left + (fun ctx (lv : typed_lvalue) -> + F.pp_print_space fmt (); + let ctx = extract_typed_lvalue ctx fmt false lv in + ctx) + ctx inputs_lvs + in + F.pp_print_string fmt "))"; + (* Close the box for the decrease clause *) + F.pp_close_box fmt ()) in (* Print the "=" *) F.pp_print_space fmt (); F.pp_print_string fmt "="; (* Close the box for "let FUN_NAME (PARAMS) =" *) F.pp_close_box fmt (); - F.pp_print_break fmt 1 ctx.indent_incr; + F.pp_print_break fmt 1 ctx_body.indent_incr; (* Open a box for the body *) F.pp_open_hvbox fmt 0; (* Extract the body *) - let _ = extract_texpression ctx fmt false def.body in + let _ = extract_texpression ctx_body fmt false def.body in (* Close the box for the body *) F.pp_close_box fmt (); (* Close the box for the definition *) diff --git a/src/Logging.ml b/src/Logging.ml index 55b07ed4..a7313de5 100644 --- a/src/Logging.ml +++ b/src/Logging.ml @@ -122,15 +122,16 @@ let style_to_codes s = | Fg c -> (to_fg_code c, to_fg_code Default) | Bg c -> (to_bg_code c, to_bg_code Default) -(** TODO: comes from easy_logging (did not manage to reuse the module directly) *) +(** TODO: comes from easy_logging (did not manage to reuse the module directly) + I made a minor modifications, though. *) let level_to_color (lvl : L.level) = match lvl with | L.Flash -> LMagenta | Error -> LRed | Warning -> LYellow - | Info -> LBlue + | Info -> LGreen | Trace -> Cyan - | Debug -> Green + | Debug -> LBlue | NoLevel -> Default (** [format styles str] formats [str] to the given [styles] - diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index dba6b5e8..b4f16e6e 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -617,10 +617,14 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) in (* Visit the body *) let body, used_vars = expr_visitor#visit_texpression () def.body in - (* Visit the parameters *) + (* Visit the parameters - TODO: update: we need filter only if the definition + * is not recursive (otherwise it might mess up with the decrease clauses). + * For now we deactivate the filtering *) let used_vars = used_vars () in let inputs_lvs = - List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs + if false then + List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs + else def.inputs_lvs in (* Return *) { def with body; inputs_lvs } diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index bb010a05..4d29d517 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -446,6 +446,10 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = ctx_get (VariantId (def_id, variant_id)) ctx +let ctx_get_decrease_clause (def_id : FunDefId.id) (ctx : extraction_ctx) : + string = + ctx_get (DecreaseClauseId (A.Local def_id)) ctx + (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = diff --git a/src/Translate.ml b/src/Translate.ml index 9d5d2b75..3c12bc90 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -326,18 +326,25 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) ExtractToFStar.extract_type_def ctx.extract_ctx fmt qualif def in + (* Utility to check a function has a decrease clause *) + let has_decrease_clause (def : Pure.fun_def) : bool = + Pure.FunDefId.Set.mem def.def_id ctx.functions_with_decrease_clause + in + (* In case of (non-mutually) recursive functions, we use a simple procedure to * check if the forward and backward functions are mutually recursive. *) let export_functions (is_rec : bool) (pure_ls : (bool * pure_fun_translation) list) : unit = (* Concatenate the function definitions, filtering the useless forward - * functions. *) + * functions. We also make pairs: (forward function, backward function) + * (the forward function contains useful information that we want to keep) *) let fls = List.concat (List.map (fun (keep_fwd, (fwd, back_ls)) -> - if keep_fwd then fwd :: back_ls else back_ls) + let back_ls = List.map (fun back -> (fwd, back)) back_ls in + if keep_fwd then (fwd, fwd) :: back_ls else back_ls) pure_ls) in (* Extract the function definitions *) @@ -348,19 +355,23 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) let is_mut_rec = if is_rec then if List.length pure_ls <= 1 then - not (PureUtils.functions_not_mutually_recursive fls) + not (PureUtils.functions_not_mutually_recursive (List.map fst fls)) else true else false in List.iteri - (fun i def -> + (fun i (fwd_def, def) -> let qualif = if not is_rec then ExtractToFStar.Let else if is_mut_rec then if i = 0 then ExtractToFStar.LetRec else ExtractToFStar.And else ExtractToFStar.LetRec in - ExtractToFStar.extract_fun_def ctx.extract_ctx fmt qualif def) + let has_decr_clause = + has_decrease_clause def && config.extract_decrease_clauses + in + ExtractToFStar.extract_fun_def ctx.extract_ctx fmt qualif + has_decr_clause fwd_def def) fls); (* Insert unit tests if necessary *) if config.test_unit_functions then @@ -520,8 +531,9 @@ let translate_module (filename : string) (dest_dir : string) (config : config) let gen_config = { extract_types = true; - extract_decrease_clauses = true; - extract_template_decrease_clauses = false; + extract_decrease_clauses = config.extract_decrease_clauses; + extract_template_decrease_clauses = + config.extract_template_decrease_clauses; extract_fun_defs = true; test_unit_functions = config.test_unit_functions; } |