summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-04 12:17:24 +0100
committerSon Ho2022-02-04 12:17:24 +0100
commit6ae85370a6d385e6824753f08ac593d22d6fc958 (patch)
tree3f08defbb20ce5d56d5136f249a2960294159558 /src
parent1f4e6c1dbf32bbb58288b1b96ede898f36284977 (diff)
Add generation of unit tests for the synthesized functions
Diffstat (limited to 'src')
-rw-r--r--src/ExtractToFStar.ml51
-rw-r--r--src/Translate.ml57
-rw-r--r--src/main.ml3
3 files changed, 86 insertions, 25 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 919a5b05..7e4a11fe 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -2,6 +2,7 @@
open Errors
open Pure
+open PureUtils
open TranslateCore
open PureToExtract
open StringUtils
@@ -85,6 +86,7 @@ let fstar_keywords =
"match";
"with";
"assert";
+ "assert_norm";
"Type0";
"unit";
"not";
@@ -966,3 +968,52 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0
+
+(** Extract a unit test, if the function is a unit function (takes no
+ parameters, returns unit).
+
+ A unit test simply checks that the function normalizes to `Return ()`:
+ ```
+ let _ = assert_norm (FUNCTION () = Return ())
+ ```
+ *)
+let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
+ (def : fun_def) : unit =
+ (* We only insert unit tests for forward functions *)
+ assert (def.back_id = None);
+ (* Check if this is a unit function *)
+ let sg = def.signature in
+ if
+ sg.type_params = []
+ && sg.inputs = [ unit_ty ]
+ && sg.outputs = [ mk_result_ty unit_ty ]
+ then (
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment *)
+ F.pp_print_string fmt
+ ("(** Unit test for [" ^ Print.name_to_string def.basename ^ "] *)");
+ F.pp_print_space fmt ();
+ (* Open a box for the test *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* Print the test *)
+ F.pp_print_string fmt "let _ =";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "assert_norm";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ let fun_name = ctx_get_local_function def.def_id def.back_id ctx in
+ F.pp_print_string fmt fun_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "()";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "=";
+ F.pp_print_space fmt ();
+ let success = ctx_get_variant (Assumed Result) result_return_id ctx in
+ F.pp_print_string fmt (success ^ " ())");
+ (* Close the box for the test *)
+ F.pp_close_box fmt ();
+ (* Add a break after *)
+ F.pp_print_break fmt 0 0)
+ else (* Do nothing *)
+ ()
diff --git a/src/Translate.ml b/src/Translate.ml
index cca131e6..e79a47c4 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -259,9 +259,17 @@ let translate_module_to_pure (config : C.partial_config) (m : M.cfim_module) :
(* Return *)
(trans_ctx, type_defs, pure_translations)
-(** Translate a module and write the synthesized code to an output file *)
+(** Translate a module and write the synthesized code to an output file.
+
+ [test_unit_functions]: if true, insert tests in the generated files to
+ check that the unit functions normalize to `Success _`. For instance,
+ in F* it generates code like this:
+ ```
+ let _ = assert_norm (FUNCTION () = Success ())
+ ```
+ *)
let translate_module (filename : string) (config : C.partial_config)
- (m : M.cfim_module) : unit =
+ (test_unit_functions : bool) (m : M.cfim_module) : unit =
(* Translate the module to the pure AST *)
let trans_ctx, trans_types, trans_funs = translate_module_to_pure config m in
@@ -350,11 +358,17 @@ let translate_module (filename : string) (config : C.partial_config)
let def = Pure.TypeDefId.Map.find id trans_types in
ExtractToFStar.extract_type_def extract_ctx fmt qualif def
in
+
(* In case of (non-mutually) recursive functions, we use a simple procedure to
* check if the forward and backward functions are mutually recursive.
*)
- let export_functions (is_rec : bool) (is_mut_rec : bool)
- (fls : Pure.fun_def list) : unit =
+ let export_functions (is_rec : bool) (pure_ls : pure_fun_translation list) :
+ unit =
+ (* Generate the function definitions *)
+ let is_mut_rec = is_rec && pure_ls <> [] in
+ let fls =
+ List.concat (List.map (fun (fwd, back_ls) -> fwd :: back_ls) pure_ls)
+ in
List.iteri
(fun i def ->
let qualif =
@@ -364,8 +378,15 @@ let translate_module (filename : string) (config : C.partial_config)
else ExtractToFStar.LetRec
in
ExtractToFStar.extract_fun_def extract_ctx fmt qualif def)
- fls
+ fls;
+ (* Insert unit tests if necessary *)
+ if test_unit_functions then
+ List.iter
+ (fun (fwd, _) ->
+ ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd)
+ pure_ls
in
+
let export_decl (decl : M.declaration_group) : unit =
match decl with
| Type (NonRec id) -> export_type ExtractToFStar.Type id
@@ -378,30 +399,18 @@ let translate_module (filename : string) (config : C.partial_config)
export_type qualif id)
ids
| Fun (NonRec id) ->
- (* Concatenate *)
- let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in
- let fls = fwd :: back_ls in
- (* Translate *)
- export_functions false false fls
- | Fun (Rec [ id ]) ->
- (* Simply recursive functions *)
- (* Concatenate *)
- let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in
- let fls = fwd :: back_ls in
- (* Check if mutually rec *)
- let is_mut_rec = not (PureUtils.functions_not_mutually_recursive fls) in
+ (* Lookup *)
+ let pure_fun = Pure.FunDefId.Map.find id trans_funs in
(* Translate *)
- export_functions true is_mut_rec fls
+ export_functions false [ pure_fun ]
| Fun (Rec ids) ->
(* General case of mutually recursive functions *)
- (* Concatenate *)
- let compute_fun_id_list (id : Pure.FunDefId.id) : Pure.fun_def list =
- let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in
- fwd :: back_ls
+ (* Lookup *)
+ let pure_funs =
+ List.map (fun id -> Pure.FunDefId.Map.find id trans_funs) ids
in
- let fls = List.concat (List.map compute_fun_id_list ids) in
(* Translate *)
- export_functions true true fls
+ export_functions true pure_funs
in
List.iter export_decl m.declarations;
diff --git a/src/main.ml b/src/main.ml
index 1c2b0fe8..bb7b0e06 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -84,4 +84,5 @@ let () =
I.Test.test_functions_symbolic config synthesize m;
(* Translate the functions *)
- Translate.translate_module !filename config m
+ let test_unit_functions = true in
+ Translate.translate_module !filename config test_unit_functions m