diff options
author | Son Ho | 2022-02-09 16:57:15 +0100 |
---|---|---|
committer | Son Ho | 2022-02-09 16:57:15 +0100 |
commit | 3cd24d0b0ecd4a7a71587a5f1479852f40f959ff (patch) | |
tree | 9605cea75b0153e04526e79c2604ebb0026b9298 | |
parent | dd1a786022b493c10da6f4d6d1c88a41b70e1eb5 (diff) |
Implement generation of template decrease clauses
-rw-r--r-- | src/ExtractToFStar.ml | 104 | ||||
-rw-r--r-- | src/Translate.ml | 9 |
2 files changed, 108 insertions, 5 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 95a29c56..a481affd 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -940,6 +940,99 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) F.pp_close_box fmt ()) | Meta (_, e) -> extract_texpression ctx fmt inside e +(** A small utility to print the parameters of a function signature. + + We return two contexts: + - the context augmented with bindings for the type parameters + - the previous context augmented with bindings for the input values + *) +let extract_fun_parameters (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_def) : extraction_ctx * extraction_ctx = + (* Add the type parameters - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, _ = ctx_add_type_params def.signature.type_params ctx in + (* Print the parameters - rk.: we should have filtered the functions + * with no input parameters *) + (* The type parameters *) + if def.signature.type_params <> [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + List.iter + (fun (p : type_var) -> + let pname = ctx_get_type_var p.index ctx in + F.pp_print_string fmt pname; + F.pp_print_space fmt ()) + def.signature.type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type0)"); + (* The input parameters - note that doing this adds bindings to the context *) + let ctx_body = + List.fold_left + (fun ctx (lv : typed_lvalue) -> + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + let ctx = extract_typed_lvalue ctx fmt false lv in + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt false lv.ty; + F.pp_print_string fmt ")"; + ctx) + ctx def.inputs_lvs + in + (ctx, ctx_body) + +(** Extract a decrease clause function template body. + + In order to help the user, we can generate a template for the functions + required by the decreases clauses. We simply generate definitions of + the following form in a separate file: + ``` + let f_decrease (t : Type0) (x : t) : nat = admit() + ``` + + Where the translated functions for `f` look like this: + ``` + let f_fwd (t : Type0) (x : t) : Tot ... (decreases (f_decrease t x)) = ... + ``` + *) +let extract_template_decrease_clause (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_def) : unit = + (* Retrieve the function name *) + let def_name = ctx_get_decrease_clause def.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + F.pp_print_string fmt + ("(** [" ^ Print.name_to_string def.basename ^ "]: decrease clause *)"); + F.pp_print_space fmt (); + (* Open a box for the definition, so that whenever possible it gets printed on + * one line - TODO: remove? *) + F.pp_open_hvbox fmt 0; + (* Open a box for "let FUN_NAME (PARAMS) =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "let FUN_NAME" *) + F.pp_print_string fmt ("let " ^ def_name); + (* Extract the parameters *) + let _, _ = extract_fun_parameters ctx fmt def in + (* Print the signature *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "nat"; + (* Print the body *) + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + F.pp_print_space fmt (); + F.pp_print_string fmt "admit ()"; + (* Close the box for "let FUN_NAME (PARAMS) =" *) + F.pp_close_box fmt (); + (* Close the box for the whole definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + (** Extract a function definition. Note that all the names used for extraction should already have been @@ -954,9 +1047,9 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_def) : unit = (* Retrieve the function name *) let def_name = ctx_get_local_function def.def_id def.back_id ctx in - (* Add the type parameters - note that we need those bindings only for the - * body translation (they are not top-level) *) - let ctx, _ = ctx_add_type_params def.signature.type_params ctx in + (* (* Add the type parameters - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, _ = ctx_add_type_params def.signature.type_params ctx in *) (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -972,7 +1065,8 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) match qualif with Let -> "let" | LetRec -> "let rec" | And -> "and" in F.pp_print_string fmt (qualif ^ " " ^ def_name); - (* Print the parameters - rk.: we should have filtered the functions + let ctx, ctx_body = extract_fun_parameters ctx fmt def in + (*(* Print the parameters - rk.: we should have filtered the functions * with no input parameters *) (* The type parameters *) if def.signature.type_params <> [] then ( @@ -1001,7 +1095,7 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ")"; ctx) ctx def.inputs_lvs - in + in*) (* Print the return type - note that we have to be careful when * printing the input values for the decrease clause, because * it introduces bindings in the context... We thus "forget" diff --git a/src/Translate.ml b/src/Translate.ml index 3c12bc90..ec817b50 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -347,6 +347,15 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if keep_fwd then (fwd, fwd) :: back_ls else back_ls) pure_ls) in + (* Extract the decrease clauses template bodies *) + if config.extract_template_decrease_clauses then + List.iter + (fun (_, (fwd, _)) -> + let has_decr_clause = has_decrease_clause fwd in + if has_decr_clause then + ExtractToFStar.extract_template_decrease_clause ctx.extract_ctx fmt + fwd) + pure_ls; (* Extract the function definitions *) (if config.extract_fun_defs then (* Check if the functions are mutually recursive - this really works |