diff options
author | Son Ho | 2022-02-04 12:17:24 +0100 |
---|---|---|
committer | Son Ho | 2022-02-04 12:17:24 +0100 |
commit | 6ae85370a6d385e6824753f08ac593d22d6fc958 (patch) | |
tree | 3f08defbb20ce5d56d5136f249a2960294159558 /src | |
parent | 1f4e6c1dbf32bbb58288b1b96ede898f36284977 (diff) |
Add generation of unit tests for the synthesized functions
Diffstat (limited to 'src')
-rw-r--r-- | src/ExtractToFStar.ml | 51 | ||||
-rw-r--r-- | src/Translate.ml | 57 | ||||
-rw-r--r-- | src/main.ml | 3 |
3 files changed, 86 insertions, 25 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 919a5b05..7e4a11fe 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -2,6 +2,7 @@ open Errors open Pure +open PureUtils open TranslateCore open PureToExtract open StringUtils @@ -85,6 +86,7 @@ let fstar_keywords = "match"; "with"; "assert"; + "assert_norm"; "Type0"; "unit"; "not"; @@ -966,3 +968,52 @@ let extract_fun_def (ctx : extraction_ctx) (fmt : F.formatter) F.pp_close_box fmt (); (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0 + +(** Extract a unit test, if the function is a unit function (takes no + parameters, returns unit). + + A unit test simply checks that the function normalizes to `Return ()`: + ``` + let _ = assert_norm (FUNCTION () = Return ()) + ``` + *) +let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) + (def : fun_def) : unit = + (* We only insert unit tests for forward functions *) + assert (def.back_id = None); + (* Check if this is a unit function *) + let sg = def.signature in + if + sg.type_params = [] + && sg.inputs = [ unit_ty ] + && sg.outputs = [ mk_result_ty unit_ty ] + then ( + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment *) + F.pp_print_string fmt + ("(** Unit test for [" ^ Print.name_to_string def.basename ^ "] *)"); + F.pp_print_space fmt (); + (* Open a box for the test *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the test *) + F.pp_print_string fmt "let _ ="; + F.pp_print_space fmt (); + F.pp_print_string fmt "assert_norm"; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + let fun_name = ctx_get_local_function def.def_id def.back_id ctx in + F.pp_print_string fmt fun_name; + F.pp_print_space fmt (); + F.pp_print_string fmt "()"; + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + F.pp_print_space fmt (); + let success = ctx_get_variant (Assumed Result) result_return_id ctx in + F.pp_print_string fmt (success ^ " ())"); + (* Close the box for the test *) + F.pp_close_box fmt (); + (* Add a break after *) + F.pp_print_break fmt 0 0) + else (* Do nothing *) + () diff --git a/src/Translate.ml b/src/Translate.ml index cca131e6..e79a47c4 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -259,9 +259,17 @@ let translate_module_to_pure (config : C.partial_config) (m : M.cfim_module) : (* Return *) (trans_ctx, type_defs, pure_translations) -(** Translate a module and write the synthesized code to an output file *) +(** Translate a module and write the synthesized code to an output file. + + [test_unit_functions]: if true, insert tests in the generated files to + check that the unit functions normalize to `Success _`. For instance, + in F* it generates code like this: + ``` + let _ = assert_norm (FUNCTION () = Success ()) + ``` + *) let translate_module (filename : string) (config : C.partial_config) - (m : M.cfim_module) : unit = + (test_unit_functions : bool) (m : M.cfim_module) : unit = (* Translate the module to the pure AST *) let trans_ctx, trans_types, trans_funs = translate_module_to_pure config m in @@ -350,11 +358,17 @@ let translate_module (filename : string) (config : C.partial_config) let def = Pure.TypeDefId.Map.find id trans_types in ExtractToFStar.extract_type_def extract_ctx fmt qualif def in + (* In case of (non-mutually) recursive functions, we use a simple procedure to * check if the forward and backward functions are mutually recursive. *) - let export_functions (is_rec : bool) (is_mut_rec : bool) - (fls : Pure.fun_def list) : unit = + let export_functions (is_rec : bool) (pure_ls : pure_fun_translation list) : + unit = + (* Generate the function definitions *) + let is_mut_rec = is_rec && pure_ls <> [] in + let fls = + List.concat (List.map (fun (fwd, back_ls) -> fwd :: back_ls) pure_ls) + in List.iteri (fun i def -> let qualif = @@ -364,8 +378,15 @@ let translate_module (filename : string) (config : C.partial_config) else ExtractToFStar.LetRec in ExtractToFStar.extract_fun_def extract_ctx fmt qualif def) - fls + fls; + (* Insert unit tests if necessary *) + if test_unit_functions then + List.iter + (fun (fwd, _) -> + ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd) + pure_ls in + let export_decl (decl : M.declaration_group) : unit = match decl with | Type (NonRec id) -> export_type ExtractToFStar.Type id @@ -378,30 +399,18 @@ let translate_module (filename : string) (config : C.partial_config) export_type qualif id) ids | Fun (NonRec id) -> - (* Concatenate *) - let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in - let fls = fwd :: back_ls in - (* Translate *) - export_functions false false fls - | Fun (Rec [ id ]) -> - (* Simply recursive functions *) - (* Concatenate *) - let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in - let fls = fwd :: back_ls in - (* Check if mutually rec *) - let is_mut_rec = not (PureUtils.functions_not_mutually_recursive fls) in + (* Lookup *) + let pure_fun = Pure.FunDefId.Map.find id trans_funs in (* Translate *) - export_functions true is_mut_rec fls + export_functions false [ pure_fun ] | Fun (Rec ids) -> (* General case of mutually recursive functions *) - (* Concatenate *) - let compute_fun_id_list (id : Pure.FunDefId.id) : Pure.fun_def list = - let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in - fwd :: back_ls + (* Lookup *) + let pure_funs = + List.map (fun id -> Pure.FunDefId.Map.find id trans_funs) ids in - let fls = List.concat (List.map compute_fun_id_list ids) in (* Translate *) - export_functions true true fls + export_functions true pure_funs in List.iter export_decl m.declarations; diff --git a/src/main.ml b/src/main.ml index 1c2b0fe8..bb7b0e06 100644 --- a/src/main.ml +++ b/src/main.ml @@ -84,4 +84,5 @@ let () = I.Test.test_functions_symbolic config synthesize m; (* Translate the functions *) - Translate.translate_module !filename config m + let test_unit_functions = true in + Translate.translate_module !filename config test_unit_functions m |